def __init__(self,
             remote: Node,
             privkey: datatypes.PrivateKey,
             reader: asyncio.StreamReader,
             writer: asyncio.StreamWriter,
             aes_secret: bytes,
             mac_secret: bytes,
             egress_mac: PreImage,
             ingress_mac: PreImage,
             chaindb: ChainDB,
             network_id: int,
             ) -> None:
    """Store the connection state for one remote node and build its ciphers.

    An AES-CTR cipher keyed with ``aes_secret`` provides ``aes_enc``/``aes_dec``; the
    ``update`` method of an AES-ECB encryptor keyed with ``mac_secret`` is stored as
    ``mac_enc``.
    """
    self._finished = asyncio.Event()
    self.remote = remote
    self.privkey = privkey
    self.reader, self.writer = reader, writer
    self.base_protocol = P2PProtocol(self)
    self.chaindb = chaindb
    self.network_id = network_id
    self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
    self.cancel_token = CancelToken('Peer')

    self.egress_mac, self.ingress_mac = egress_mac, ingress_mac

    # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
    iv = b"\x00" * 16
    frame_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
    self.aes_enc = frame_cipher.encryptor()
    self.aes_dec = frame_cipher.decryptor()

    mac_encryptor = Cipher(
        algorithms.AES(mac_secret), modes.ECB(), default_backend()).encryptor()
    self.mac_enc = mac_encryptor.update
def __init__(self,
             state_db: BaseDB,
             root_hash: bytes,
             peer_pool: PeerPool) -> None:
    """Subscribe to ``peer_pool`` and create the StateSync scheduler for ``root_hash``."""
    self.peer_pool = peer_pool
    peer_pool.subscribe(self)
    self.root_hash = root_hash
    scheduler = StateSync(root_hash, state_db, self.logger)
    self.scheduler = scheduler
    # Peers for which a handler is currently running.
    self._running_peers = set()  # type: Set[ETHPeer]
    self.cancel_token = CancelToken('StateDownloader')
def __init__(self, chaindb: ChainDB, peer_pool: PeerPool) -> None:
    """Initialise the underlying chain with ``chaindb`` and subscribe to ``peer_pool``."""
    super(LightChain, self).__init__(chaindb)
    self.peer_pool = peer_pool
    peer_pool.subscribe(self)
    # (peer, head_info) pairs waiting to be processed.
    announcements = asyncio.Queue()  # type: asyncio.Queue[Tuple[LESPeer, les.HeadInfo]]  # noqa: E501
    self._announcement_queue = announcements
    # Most recent head info stored for each peer.
    last_processed = {}  # type: Dict[LESPeer, les.HeadInfo]
    self._last_processed_announcements = last_processed
    self.cancel_token = CancelToken('LightChain')
    self._running_peers = set()  # type: Set[LESPeer]
def test_token_chain_trigger_chain():
    """Triggering a chained token must not trigger the tokens it was built from."""
    links = [CancelToken(name) for name in ('token', 'token2', 'token3')]
    first, second, third = links
    chain = first.chain(second).chain(third)
    assert not chain.triggered

    chain.trigger()

    assert chain.triggered
    assert chain.triggered_token == chain
    for link in links:
        assert not link.triggered
def __init__(self,
             peer_class: Type[BasePeer],
             chaindb: ChainDB,
             network_id: int,
             privkey: datatypes.PrivateKey,
             ) -> None:
    """Store the handshake parameters and start with no connections or subscribers."""
    self.peer_class, self.chaindb = peer_class, chaindb
    self.network_id, self.privkey = network_id, privkey
    # Connected peers, keyed by their remote node.
    self.connected_nodes = {}  # type: Dict[Node, BasePeer]
    self.cancel_token = CancelToken('PeerPool')
    self._subscribers = []  # type: List[PeerPoolSubscriber]
async def test_wait_cancel_pending_tasks_on_cancellation(event_loop):
    """Test that cancelling a pending CancelToken.wait() coroutine doesn't leave .wait()
    coroutines for any chained tokens behind.
    """
    token = CancelToken('token').chain(CancelToken('token2')).chain(CancelToken('token3'))
    # asyncio.wait() only accepts futures/tasks on recent Python versions (bare coroutines
    # were deprecated in 3.8 and are rejected from 3.11), so wrap the coroutine explicitly.
    wait_task = asyncio.ensure_future(token.wait())
    done, pending = await asyncio.wait([wait_task], timeout=0.1)
    assert len(done) == 0
    assert len(pending) == 1
    pending_task = pending.pop()
    # asyncio.wait() hands back the very objects it was given, so compare identity instead
    # of peeking at the private Task._coro attribute.
    assert pending_task is wait_task
    pending_task.cancel()
    await assert_only_current_task_not_done()
async def test_wait_cancel_pending_tasks_on_completion(event_loop):
    """Waiting on a chain must return once one of its links has been triggered and must
    leave no other task pending.
    """
    first = CancelToken('token')
    second = CancelToken('token2')
    chained = first.chain(second)

    second.trigger()
    await chained.wait()

    await assert_only_current_task_not_done()
class StateDownloader(PeerPoolSubscriber):
    """Drive a StateSync scheduler by requesting trie nodes from the peers of a PeerPool.

    Node keys are requested from a randomly chosen peer, NodeData replies are handed to the
    scheduler, and requests that get no reply within ``_reply_timeout`` seconds are sent
    again.
    """
    logger = logging.getLogger("evm.p2p.state.StateDownloader")
    _total_processed_nodes = 0
    _report_interval = 10  # Number of seconds between progress reports.
    # TODO: Experiment with different timeout/max_pending values to find the combination that
    # yields the best results.
    # FIXME: Should use the # of peers times MAX_STATE_FETCH here
    _max_pending = 5 * MAX_STATE_FETCH
    _reply_timeout = 10  # seconds
    # For simplicity/readability we use 0 here to force a report on the first iteration of the
    # loop.
    _last_report_time = 0

    def __init__(self, state_db: BaseDB, root_hash: bytes, peer_pool: PeerPool) -> None:
        """Create the scheduler for ``root_hash`` and subscribe to ``peer_pool``."""
        self.peer_pool = peer_pool
        self.root_hash = root_hash
        self.scheduler = StateSync(root_hash, state_db, self.logger)
        # Maps the key of every requested-but-not-yet-received node to the time of the
        # request. Created per instance: a class-level dict would be shared (and mutated) by
        # every StateDownloader.
        self._pending_nodes = {}  # type: Dict[Any, float]
        self._running_peers = set()  # type: Set[ETHPeer]
        self.cancel_token = CancelToken('StateDownloader')
        # Subscribe last so register_peer() is only ever called on a fully set-up instance.
        self.peer_pool.subscribe(self)

    def register_peer(self, peer: BasePeer) -> None:
        """Schedule handle_peer() for a peer reported by the PeerPool."""
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        """Process sub-protocol messages from ``peer`` until a cancel token is triggered."""
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of
                # the loop.
                break

            if isinstance(cmd, eth.NodeData):
                self.logger.debug("Processing NodeData with %d entries", len(msg))
                for node in msg:
                    self._total_processed_nodes += 1
                    node_key = keccak(node)
                    try:
                        self.scheduler.process([(node_key, node)])
                    except SyncRequestAlreadyProcessed:
                        # This means we received a node more than once, which can happen
                        # when we retry after a timeout.
                        pass
                    # A node may be received more than once, so pop() with a default value.
                    self._pending_nodes.pop(node_key, None)
            else:
                # It'd be very convenient if we could ignore everything that is not a
                # NodeData when doing a StateSync, but need to double check because peers may
                # consider that "Bad Form" and disconnect from us.
                self.logger.debug("Ignoring %s(%s) while doing a StateSync", cmd, msg)

    # FIXME: Need a better criteria to select peers here.
    async def get_random_peer(self) -> ETHPeer:
        """Return a random connected peer, sleeping until the pool has at least one."""
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)
        peer = random.choice(self.peer_pool.peers)
        return cast(ETHPeer, peer)

    async def stop(self):
        """Trigger our cancel token, unsubscribe and wait for the peer handlers to return."""
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        while self._running_peers:
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)

    async def request_next_batch(self):
        """Request up to MAX_STATE_FETCH nodes from the scheduler's next batch, if any."""
        requests = self.scheduler.next_batch(MAX_STATE_FETCH)
        if not requests:
            # Although our run() loop frequently yields control to let our msg handler
            # process received nodes (scheduling new requests), there may be cases when the
            # pending nodes take a while to arrive thus causing the scheduler to run out of
            # new requests for a while.
            self.logger.debug("Scheduler queue is empty, not requesting any nodes")
            return
        self.logger.debug("Requesting %d trie nodes", len(requests))
        await self.request_nodes([request.node_key for request in requests])

    async def request_nodes(self, node_keys: List[bytes]):
        """Mark ``node_keys`` as pending as of now and request them from a random peer."""
        peer = await self.get_random_peer()
        now = time.time()
        for node_key in node_keys:
            self._pending_nodes[node_key] = now
        peer.sub_proto.send_get_node_data(node_keys)

    async def retry_timedout(self):
        """Re-request every pending node whose request is older than ``_reply_timeout``."""
        now = time.time()
        timed_out = [
            node_key
            for node_key, req_time in self._pending_nodes.items()
            if now - req_time > self._reply_timeout
        ]
        if not timed_out:
            return
        self.logger.debug("Re-requesting %d trie nodes", len(timed_out))
        await self.request_nodes(timed_out)

    async def run(self):
        """Request nodes until the scheduler has nothing pending or we are asked to stop."""
        self.logger.info("Starting state sync for root hash %s", encode_hex(self.root_hash))
        while self.scheduler.has_pending_requests and not self.cancel_token.triggered:
            # Request new nodes if we haven't reached the limit of pending nodes.
            if len(self._pending_nodes) < self._max_pending:
                await self.request_next_batch()

            # Retry pending nodes that timed out.
            if self._pending_nodes:
                await self.retry_timedout()

            # ">=" mirrors the "<" guard above: at exactly the limit nothing new can be
            # requested, so back off instead of spinning on sleep(0).
            if len(self._pending_nodes) >= self._max_pending:
                # Slow down if we've reached the limit of pending nodes.
                self.logger.debug("Pending trie nodes limit reached, sleeping a bit")
                await asyncio.sleep(0.3)
            else:
                # Yield control to ensure the Peer's msg_handler callback is called to
                # process any nodes we may have received already. Otherwise we spin too fast
                # and don't process received nodes often enough.
                await asyncio.sleep(0)

            self._maybe_report_progress()

        self.logger.info("Finished state sync with root hash %s", encode_hex(self.root_hash))

    def _maybe_report_progress(self):
        """Log progress counters if ``_report_interval`` seconds passed since the last log."""
        if (time.time() - self._last_report_time) >= self._report_interval:
            self._last_report_time = time.time()
            self.logger.info("Nodes processed: %d", self._total_processed_nodes)
            self.logger.info(
                "Nodes requested but not received yet: %d", len(self._pending_nodes))
            self.logger.info(
                "Nodes scheduled but not requested yet: %d", len(self.scheduler.requests))
async def test_wait_with_token(event_loop):
    """wait_with_token() returns the awaited future's result when it completes in time."""
    future = asyncio.Future()
    resolve = functools.partial(future.set_result, 'result')
    event_loop.call_soon(resolve)

    outcome = await wait_with_token(future, CancelToken('token'), timeout=1)

    assert outcome == 'result'
    await assert_only_current_task_not_done()
def test_token_single():
    """A lone token starts untriggered and reports itself as the triggered token."""
    lone = CancelToken('token')
    assert not lone.triggered

    lone.trigger()

    assert lone.triggered
    assert lone.triggered_token == lone
async def test_wait_with_token_operation_cancelled(event_loop):
    """wait_with_token() raises OperationCancelled when given an already-triggered token."""
    triggered = CancelToken('token')
    triggered.trigger()

    with pytest.raises(OperationCancelled):
        await wait_with_token(asyncio.sleep(0.02), triggered)

    await assert_only_current_task_not_done()
async def test_wait_with_token_timeout():
    """wait_with_token() raises TimeoutError when the timeout is shorter than the wait."""
    untriggered = CancelToken('token')

    with pytest.raises(TimeoutError):
        await wait_with_token(asyncio.sleep(0.02), untriggered, timeout=0.01)

    await assert_only_current_task_not_done()
async def test_wait_with_token_future_exception(event_loop):
    """An exception set on the awaited future propagates out of wait_with_token()."""
    future = asyncio.Future()
    fail = functools.partial(future.set_exception, Exception())
    event_loop.call_soon(fail)

    with pytest.raises(Exception):
        await wait_with_token(future, CancelToken('token'), timeout=1)

    await assert_only_current_task_not_done()
class LightChain(Chain, PeerPoolSubscriber):
    """A Chain that keeps its headers in sync with the LES peers of a PeerPool.

    Head announcements from peers are queued by the per-peer handlers and consumed by run(),
    which fetches and persists the missing headers. Blocks, receipts, accounts and contract
    code are fetched from the best known peer on demand.
    """
    logger = logging.getLogger("evm.p2p.lightchain.LightChain")
    max_consecutive_timeouts = 5

    def __init__(self, chaindb: ChainDB, peer_pool: PeerPool) -> None:
        """Initialise the chain with ``chaindb`` and subscribe to ``peer_pool``."""
        super(LightChain, self).__init__(chaindb)
        self.peer_pool = peer_pool
        self._announcement_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[LESPeer, les.HeadInfo]]  # noqa: E501
        self._last_processed_announcements = {}  # type: Dict[LESPeer, les.HeadInfo]
        self.cancel_token = CancelToken('LightChain')
        self._running_peers = set()  # type: Set[LESPeer]
        # Subscribe last so register_peer() is only ever called on a fully set-up instance.
        self.peer_pool.subscribe(self)

    @classmethod
    def from_genesis_header(cls, chaindb, genesis_header, peer_pool):
        """Persist ``genesis_header`` to ``chaindb`` and return a new instance using it."""
        chaindb.persist_header_to_db(genesis_header)
        return cls(chaindb, peer_pool)

    def register_peer(self, peer: BasePeer) -> None:
        """Schedule handle_peer() for a peer reported by the PeerPool."""
        asyncio.ensure_future(self.handle_peer(cast(LESPeer, peer)))

    async def handle_peer(self, peer: LESPeer) -> None:
        """Handle the lifecycle of the given peer.

        Returns when the peer is finished or when the LightChain is asked to stop.
        """
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: LESPeer) -> None:
        """Queue the peer's current head and every later Announce until cancelled.

        Raises UnexpectedMessage for any command other than Announce.
        """
        self._announcement_queue.put_nowait((peer, peer.head_info))
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either the peer disconnected or our cancel_token has been triggered, so
                # break out of the loop to stop attempting to sync with this peer.
                break

            # We currently implement only the LES commands for retrieving data (apart from
            # Announce), and those should always come as a response to requests we make so
            # will be handled by LESPeer.handle_sub_proto_msg().
            if isinstance(cmd, les.Announce):
                peer.head_info = cmd.as_head_info(msg)
                self._announcement_queue.put_nowait((peer, peer.head_info))
            else:
                # NOTE(review): raising here skips drop_peer() below -- confirm that leaving
                # the peer connected after an unexpected message is intended.
                raise UnexpectedMessage("Unexpected msg from {}: {}".format(peer, msg))

        await self.drop_peer(peer)
        self.logger.debug("%s finished", peer)

    async def drop_peer(self, peer: LESPeer) -> None:
        """Forget the last announcement processed for ``peer`` and stop the peer."""
        self._last_processed_announcements.pop(peer, None)
        await peer.stop()

    async def wait_until_finished(self):
        """Wait up to 5 seconds for all running peer handlers to finish."""
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)
        else:
            self.logger.info("Waited too long for peers to finish, exiting anyway")

    async def get_best_peer(self) -> LESPeer:
        """
        Return the peer with the highest announced block height.
        """
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)

        def peer_block_height(peer: LESPeer):
            last_announced = self._last_processed_announcements.get(peer)
            if last_announced is None:
                # Peers with no processed announcement rank below every other peer.
                return -1
            return last_announced.block_number

        # TODO: Should pick a random one in case there are multiple peers with the same block
        # height.
        return cast(LESPeer, max(self.peer_pool.peers, key=peer_block_height))

    async def wait_for_announcement(self) -> Tuple[LESPeer, les.HeadInfo]:
        """Wait for a new announcement from any of our connected peers.

        Returns a tuple containing the LESPeer on which the announcement was received and the
        announcement info.

        Raises OperationCancelled when LightChain.stop() has been called.
        """
        # Wait for either a new announcement or our cancel_token to be triggered.
        return await wait_with_token(self._announcement_queue.get(), self.cancel_token)

    async def run(self) -> None:
        """Run the LightChain, ensuring headers are in sync with connected peers.

        If .stop() is called, we'll disconnect from all peers and return.
        """
        self.logger.info("Running LightChain...")
        while True:
            try:
                peer, head_info = await self.wait_for_announcement()
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break

            try:
                await self.process_announcement(peer, head_info)
                self._last_processed_announcements[peer] = head_info
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break
            except LESAnnouncementProcessingError as e:
                self.logger.warning(repr(e))
                await self.drop_peer(peer)
            except Exception as e:
                self.logger.error("Unexpected error when processing announcement: %s", repr(e))
                await self.drop_peer(peer)

    async def fetch_headers(self, start_block: int, peer: LESPeer) -> List[BlockHeader]:
        """Fetch headers starting at ``start_block`` from ``peer``, retrying on timeouts.

        Raises TooManyTimeouts after ``max_consecutive_timeouts`` timed-out attempts.
        """
        for i in range(self.max_consecutive_timeouts):
            try:
                return await peer.fetch_headers_starting_at(start_block, self.cancel_token)
            except asyncio.TimeoutError:
                self.logger.info(
                    "Timeout when fetching headers from %s (attempt %d of %d)",
                    peer, i + 1, self.max_consecutive_timeouts)
                # TODO: Figure out what's a good value to use here.
                await asyncio.sleep(0.5)
        raise TooManyTimeouts()

    async def get_sync_start_block(self, peer: LESPeer, head_info: les.HeadInfo) -> int:
        """Return the block number from which to start fetching headers from ``peer``.

        Raises LESAnnouncementProcessingError if, on first contact with the peer, fetching
        the headers prior to our head yields an empty reply or too many timeouts.
        """
        chain_head = await self.chaindb.coro_get_canonical_head()
        last_peer_announcement = self._last_processed_announcements.get(peer)
        if chain_head.block_number == GENESIS_BLOCK_NUMBER:
            start_block = GENESIS_BLOCK_NUMBER
        elif last_peer_announcement is None:
            # It's the first time we hear from this peer, need to figure out which headers to
            # get from it.  We can't simply fetch headers starting from our current head
            # number because there may have been a chain reorg, so we fetch some headers prior
            # to our head from the peer, and insert any missing ones in our DB, essentially
            # making our canonical chain identical to the peer's up to
            # chain_head.block_number.
            oldest_ancestor_to_consider = max(
                0, chain_head.block_number - peer.max_headers_fetch + 1)
            try:
                headers = await self.fetch_headers(oldest_ancestor_to_consider, peer)
            except EmptyGetBlockHeadersReply:
                raise LESAnnouncementProcessingError(
                    "No common ancestors found between us and {}".format(peer))
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(peer))
            for header in headers:
                await self.chaindb.coro_persist_header_to_db(header)
            start_block = chain_head.block_number
        else:
            start_block = last_peer_announcement.block_number - head_info.reorg_depth
        return start_block

    # TODO: Distribute requests among our peers, ensuring the selected peer has the info we
    # want and respecting the flow control rules.
    async def process_announcement(self, peer: LESPeer, head_info: les.HeadInfo) -> None:
        """Fetch and persist the headers up to the head announced by ``peer``.

        Raises LESAnnouncementProcessingError when the peer times out too often or returns a
        batch of headers that does not advance past the block we asked for.
        """
        if await self.chaindb.coro_header_exists(head_info.block_hash):
            self.logger.debug(
                "Skipping processing of %s from %s as head has already been fetched",
                head_info, peer)
            return

        start_block = await self.get_sync_start_block(peer, head_info)
        while start_block < head_info.block_number:
            try:
                # We should use "start_block + 1" here, but we always re-fetch the last synced
                # block to work around https://github.com/ethereum/go-ethereum/issues/15447
                batch = await self.fetch_headers(start_block, peer)
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(peer))
            requested_from = start_block
            for header in batch:
                await self.chaindb.coro_persist_header_to_db(header)
                start_block = header.block_number
            if start_block <= requested_from:
                # Because the last synced block is always re-fetched, a peer that keeps
                # answering with nothing newer would otherwise keep us in this loop forever.
                raise LESAnnouncementProcessingError(
                    "{} returned no new headers after block #{}".format(peer, requested_from))
            self.logger.info("synced headers up to #%s", start_block)

    async def stop(self):
        """Trigger our cancel token and wait for the running peer handlers to finish."""
        self.logger.info("Stopping LightChain...")
        self.cancel_token.trigger()
        self.logger.debug("Waiting for all pending tasks to finish...")
        await self.wait_until_finished()
        self.logger.debug("LightChain finished")

    async def get_canonical_block_by_number(self, block_number: int) -> BaseBlock:
        """Return the block with the given number from the canonical chain.

        Raises BlockNotFound if it is not found.
        """
        try:
            block_hash = await self.chaindb.coro_lookup_block_hash(block_number)
        except KeyError:
            raise BlockNotFound(
                "No block with number {} found on local chain".format(block_number))
        return await self.get_block_by_hash(block_hash)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_block_by_hash(self, block_hash: bytes) -> BaseBlock:
        """Return the block with the given hash, fetching its body from the best peer.

        The header is read from the local DB when present and fetched from the peer
        otherwise.
        """
        peer = await self.get_best_peer()
        try:
            header = await self.chaindb.coro_get_block_header_by_hash(block_hash)
        except BlockNotFound:
            self.logger.debug("Fetching header %s from %s", encode_hex(block_hash), peer)
            header = await peer.get_block_header_by_hash(block_hash, self.cancel_token)

        self.logger.debug("Fetching block %s from %s", encode_hex(block_hash), peer)
        body = await peer.get_block_by_hash(block_hash, self.cancel_token)
        block_class = self.get_vm_class_for_block_number(header.block_number).get_block_class()
        transactions = [
            block_class.transaction_class.from_base_transaction(tx)
            for tx in body.transactions
        ]
        return block_class(
            header=header,
            transactions=transactions,
            uncles=body.uncles,
        )

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_receipts(self, block_hash: bytes) -> List[Receipt]:
        """Fetch the receipts for the block with the given hash from the best peer."""
        peer = await self.get_best_peer()
        self.logger.debug("Fetching %s receipts from %s", encode_hex(block_hash), peer)
        return await peer.get_receipts(block_hash, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_account(self, block_hash: bytes, address: bytes) -> Account:
        """Fetch the account at ``address`` as of ``block_hash`` from the best peer."""
        peer = await self.get_best_peer()
        return await peer.get_account(block_hash, address, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_contract_code(self, block_hash: bytes, key: bytes) -> bytes:
        """Fetch the contract code for ``key`` as of ``block_hash`` from the best peer."""
        peer = await self.get_best_peer()
        return await peer.get_contract_code(block_hash, key, self.cancel_token)
class PeerPool:
    """PeerPool attempts to keep connections to at least min_peers on the given network."""
    logger = logging.getLogger("evm.p2p.peer.PeerPool")
    min_peers = 2
    _connect_loop_sleep = 2

    def __init__(self,
                 peer_class: Type[BasePeer],
                 chaindb: ChainDB,
                 network_id: int,
                 privkey: datatypes.PrivateKey,
                 ) -> None:
        """Store the handshake parameters and start with no connections or subscribers."""
        self.peer_class = peer_class
        self.chaindb = chaindb
        self.network_id = network_id
        self.privkey = privkey
        self.connected_nodes = {}  # type: Dict[Node, BasePeer]
        self.cancel_token = CancelToken('PeerPool')
        self._subscribers = []  # type: List[PeerPoolSubscriber]

    def subscribe(self, subscriber: PeerPoolSubscriber) -> None:
        """Add ``subscriber`` and register every currently connected peer with it."""
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)

    def unsubscribe(self, subscriber: PeerPoolSubscriber) -> None:
        """Remove ``subscriber``; unknown subscribers are ignored."""
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)

    async def get_nodes_to_connect(self) -> List[Node]:
        """Return the hard-coded nodes for our network_id.

        Raises ValueError if network_id is neither the mainnet's nor ropsten's.
        """
        # TODO: This should use the Discovery service to lookup nodes to connect to, but our
        # current implementation only supports v4 and with that it takes an insane amount of
        # time to find any LES nodes with the same network ID as us, so for now we hard-code
        # some nodes that seem to have a good uptime.
        from evm.chains.ropsten import RopstenChain
        from evm.chains.mainnet import MainnetChain
        if self.network_id == MainnetChain.network_id:
            return [
                Node(
                    keys.PublicKey(decode_hex("1118980bf48b0a3640bdba04e0fe78b1add18e1cd99bf22d53daac1fd9972ad650df52176e7c7d89d1114cfef2bc23a2959aa54998a46afcf7d91809f0855082")),  # noqa: E501
                    Address("52.74.57.123", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("78de8a0916848093c73790ead81d1928bec737d565119932b98c6b100d944b7a95e94f847f689fc723399d2e31129d182f7ef3863f2b4c820abbf3ab2722344d")),  # noqa: E501
                    Address("191.235.84.50", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("ddd81193df80128880232fc1deb45f72746019839589eeb642d3d44efbb8b2dda2c1a46a348349964a6066f8afb016eb2a8c0f3c66f32fadf4370a236a4b5286")),  # noqa: E501
                    Address("52.231.202.145", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("3f1d12044546b76342d59d4a05532c14b85aa669704bfe1f864fe079415aa2c02d743e03218e57a33fb94523adb54032871a6c51b2cc5514cb7c7e35b3ed0a99")),  # noqa: E501
                    Address("13.93.211.84", 30303, 30303)),
            ]
        elif self.network_id == RopstenChain.network_id:
            return [
                Node(
                    keys.PublicKey(decode_hex("88c2b24429a6f7683fbfd06874ae3f1e7c8b4a5ffb846e77c705ba02e2543789d66fc032b6606a8d8888eb6239a2abe5897ce83f78dcdcfcb027d6ea69aa6fe9")),  # noqa: E501
                    Address("163.172.157.61", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("a1ef9ba5550d5fac27f7cbd4e8d20a643ad75596f307c91cd6e7f85b548b8a6bf215cca436d6ee436d6135f9fe51398f8dd4c0bd6c6a0c332ccb41880f33ec12")),  # noqa: E501
                    Address("51.15.218.125", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("e80276aabb7682a4a659f4341c1199de79d91a2e500a6ee9bed16ed4ce927ba8d32ba5dea357739ffdf2c5bcc848d3064bb6f149f0b4249c1f7e53f8bf02bfc8")),  # noqa: E501
                    Address("51.15.39.57", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("584c0db89b00719e9e7b1b5c32a4a8942f379f4d5d66bb69f9c7fa97fa42f64974e7b057b35eb5a63fd7973af063f9a1d32d8c60dbb4854c64cb8ab385470258")),  # noqa: E501
                    Address("51.15.35.2", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("d40871fc3e11b2649700978e06acd68a24af54e603d4333faecb70926ca7df93baa0b7bf4e927fcad9a7c1c07f9b325b22f6d1730e728314d0e4e6523e5cebc2")),  # noqa: E501
                    Address("51.15.132.235", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("482484b9198530ee2e00db89791823244ca41dcd372242e2e1297dd06f6d8dd357603960c5ad9cc8dc15fcdf0e4edd06b7ad7db590e67a0b54f798c26581ebd7")),  # noqa: E501
                    Address("51.15.75.138", 30303, 30303)),
            ]
        else:
            raise ValueError("Unknown network_id: {}".format(self.network_id))

    async def run(self):
        """Keep trying to connect to more peers until our cancel token is triggered."""
        self.logger.info("Running PeerPool...")
        while not self.cancel_token.triggered:
            try:
                await self.maybe_connect_to_more_peers()
            except asyncio.CancelledError:
                # Cancellation of run() itself is not an "unexpected error"; let it through
                # (on older Python versions CancelledError is a subclass of Exception).
                raise
            except Exception:
                # Most unexpected errors should be transient, so we log and restart from
                # scratch.
                self.logger.error("Unexpected error (%s), restarting", traceback.format_exc())
                await self.stop_all_peers()
            # Wait self._connect_loop_sleep seconds, unless we're asked to stop.
            await self._sleep_unless_cancelled(self._connect_loop_sleep)

    async def _sleep_unless_cancelled(self, seconds: float) -> None:
        """Return after ``seconds`` or as soon as our cancel token's wait() completes."""
        # asyncio.wait() needs a future/task (bare coroutines are rejected on recent Python
        # versions) and does not cancel what is still pending when the timeout expires, so
        # own the task here and cancel it ourselves to avoid leaking one per call.
        token_wait = asyncio.ensure_future(self.cancel_token.wait())
        try:
            await asyncio.wait([token_wait], timeout=seconds)
        finally:
            if not token_wait.done():
                token_wait.cancel()

    async def stop_all_peers(self):
        """Stop every connected peer concurrently."""
        self.logger.info("Stopping all peers ...")
        await asyncio.gather(
            *[peer.stop() for peer in self.connected_nodes.values()])

    async def stop(self):
        """Trigger our cancel token and stop every connected peer."""
        self.cancel_token.trigger()
        await self.stop_all_peers()

    async def connect(self, remote: Node) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Returns None if the remote is unreachable, times out or is useless.
        """
        if remote in self.connected_nodes:
            self.logger.debug("Skipping %s; already connected to it", remote)
            return None
        expected_exceptions = (
            UnreachablePeer, asyncio.TimeoutError, PeerConnectionLost, HandshakeFailure)
        try:
            self.logger.info("Connecting to %s...", remote)
            # TODO: Use asyncio.wait() and our cancel_token here to cancel in case the token
            # is triggered.
            peer = await asyncio.wait_for(
                handshake(remote, self.privkey, self.peer_class, self.chaindb, self.network_id),
                HANDSHAKE_TIMEOUT)
            return peer
        except expected_exceptions as e:
            self.logger.info("Could not complete handshake with %s: %s", remote, repr(e))
        except Exception:
            self.logger.warning("Unexpected error during auth/p2p handshake with %s: %s",
                                remote, traceback.format_exc())
        return None

    async def maybe_connect_to_more_peers(self):
        """Connect to more peers if we're not yet connected to at least self.min_peers."""
        if len(self.connected_nodes) >= self.min_peers:
            self.logger.debug(
                "Already connected to %s peers: %s; sleeping",
                len(self.connected_nodes),
                list(self.connected_nodes))
            return

        for node in await self.get_nodes_to_connect():
            # TODO: Consider changing connect() to raise an exception instead of returning
            # None, as discussed in
            # https://github.com/pipermerriam/py-evm/pull/139#discussion_r152067425
            peer = await self.connect(node)
            if peer is not None:
                self.logger.info("Successfully connected to %s", peer)
                asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
                self.connected_nodes[peer.remote] = peer
                for subscriber in self._subscribers:
                    subscriber.register_peer(peer)

    def _peer_finished(self, peer: BasePeer) -> None:
        """Remove the given peer from our list of connected nodes.

        This is passed as a callback to be called when a peer finishes.
        """
        # pop() with a default: the peer may already have been removed.
        self.connected_nodes.pop(peer.remote, None)

    @property
    def peers(self) -> List[BasePeer]:
        """A new list holding the currently connected peers."""
        return list(self.connected_nodes.values())
class BasePeer:
    """Base class for a connection to a remote node speaking the P2P base protocol.

    It owns the reader/writer streams for the connection, encrypts/decrypts and
    authenticates the frames exchanged with the remote, performs the base protocol (P2P)
    handshake and dispatches incoming messages: base protocol messages are handled here,
    sub-protocol messages are put in self.sub_proto_msg_queue for consumers to read via
    read_sub_proto_msg().

    Subclasses must define _supported_sub_protocols and implement
    send_sub_proto_handshake() and process_sub_proto_handshake().
    """
    logger = logging.getLogger("evm.p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: PreImage,
                 ingress_mac: PreImage,
                 chaindb: ChainDB,
                 network_id: int,
                 ) -> None:
        # Set (in run()) once this peer's read loop has ended, for whatever reason.
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    def send_sub_proto_handshake(self):
        """Send the sub-protocol handshake msg to the remote. Must be implemented by subclasses."""
        raise NotImplementedError()

    def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Process the remote's sub-protocol handshake msg. Must be implemented by subclasses."""
        raise NotImplementedError()

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        The wait is bound to both this peer's cancel token and the given one.

        Raises OperationCancelled if the peer has been disconnected.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        return await wait_with_token(self.sub_proto_msg_queue.get(), combined_token)

    @property
    def genesis(self) -> BlockHeader:
        """The header of the block with number GENESIS_BLOCK_NUMBER, as found in our chaindb."""
        genesis_hash = self.chaindb.lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return self.chaindb.get_block_header_by_hash(genesis_hash)

    @property
    def _local_chain_info(self) -> 'ChainInfo':
        """A ChainInfo describing our canonical head, its score and our genesis hash."""
        genesis = self.genesis
        head = self.chaindb.get_canonical_head()
        total_difficulty = self.chaindb.get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[bytes, int]]:
        """The (name, version) pairs of all sub-protocols we support."""
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg.

        Raises UnknownProtocolCommand if the cmd_id belongs to neither the base protocol nor
        the sub-protocol in use (including when no sub-protocol has been selected yet).
        """
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif (self.sub_proto is not None and
                cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length):
            # self.sub_proto is only set once the P2P handshake completes, hence the None check.
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        """Read exactly n bytes from the remote.

        The wait is bound to our cancel token and to self.conn_idle_timeout.

        Raises PeerConnectionLost if the stream ends or the connection is reset before n bytes
        could be read.
        """
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), self.cancel_token, timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError) as e:
            raise PeerConnectionLost("EOF reading from stream") from e

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        """Run the read loop until it ends, then flag this peer as finished.

        OperationCancelled and unexpected errors are logged, not propagated.

        :param finished_callback: If given, called with this peer once the read loop has ended.
        """
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.error(
                "Unexpected error when handling remote msg: %s", traceback.format_exc())
        finally:
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        self.close()
        # self._finished is set by run() when the read loop ends.
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        """Read and process msgs from the remote until it stops responding."""
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, asyncio.TimeoutError) as e:
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return
            self.process_msg(cmd, msg)

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read one frame from the remote and return its command together with the decoded msg."""
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to 16-byte boundary,
        # so need to do this here to ensure we read all the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        decoded_msg = cmd.decode(msg)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType):
        """Handle the base protocol (P2P) messages.

        Raises UnexpectedMessage for anything other than Disconnect, Ping or Pong.
        """
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            self.logger.debug(
                "%s disconnected; reason given: %s", self, msg['reason_name'])
            self.close()
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType):
        """Queue the given sub-protocol msg for whoever calls read_sub_proto_msg()."""
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Dispatch the msg to the base protocol handler or to the sub-protocol one."""
        if isinstance(cmd.proto, P2PProtocol):
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Process the remote's reply to our P2P handshake and select the sub-protocol to use.

        Disconnects from the remote and raises HandshakeFailure if the reply is not a Hello msg
        or if we have no capabilities in common with the remote.
        """
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        self.sub_proto = self.select_sub_protocol(remote_capabilities)
        if self.sub_proto is None:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        """Encrypt the given header and frame, returning them with their MACs appended.

        The returned bytes are: header ciphertext + header MAC + frame ciphertext + frame MAC.
        Both our AES encryptor and self.egress_mac are stateful and updated by this call.

        Raises ValueError if the header is not HEADER_LEN bytes long.
        """
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        """Verify the MAC of the given header data and return the decrypted header.

        Both self.ingress_mac and our AES decryptor are stateful and updated by this call.

        Raises ValueError if data is not HEADER_LEN + MAC_LEN bytes long, and
        AuthenticationError if the MAC does not match.
        """
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        """Verify the MAC of the given frame data and return the first body_size decrypted bytes.

        Both self.ingress_mac and our AES decryptor are stateful and updated by this call.

        Raises ValueError if data is shorter than the (16-byte padded) body plus its MAC, and
        AuthenticationError if the MAC does not match.
        """
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; Got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        # Drop the padding that was added to reach the 16-byte boundary.
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        """Return the frame size encoded in the first 3 bytes of the given (decrypted) header."""
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size

    def send(self, header: bytes, body: bytes) -> None:
        """Encrypt the given header and body and write them to the remote."""
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.

        Raises ValueError if reason is not an item of DisconnectReason.
        """
        # Validate first: reason.value is only safe to access on an actual DisconnectReason.
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> Optional[protocol.Protocol]:
        """Select the sub-protocol to use when talking to the remote.

        Find the highest version of our supported sub-protocols that is also supported by the
        remote and return an instance of it (with the appropriate cmd_id offset).

        Returns None if we have no capabilities in common with the remote.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        if not matching_capabilities:
            # max() raises ValueError on an empty set, but callers expect None in this case.
            return None
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        # Sub-protocol cmd ids start right after the base protocol's.
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        return None

    def __str__(self):
        return "{} {}".format(self.__class__.__name__, self.remote)