Example #1
    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: PreImage,
                 ingress_mac: PreImage,
                 chaindb: ChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update
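The constructor above sets up the RLPx framing ciphers with the `cryptography` package: one AES-CTR context whose encryptor and decryptor share the all-zero IV flagged in the FIXME, and an AES-ECB context whose update function is kept as the MAC keystream. A standalone sketch of that API, using a made-up key in place of the handshake-derived secrets:

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

aes_secret = b"\x11" * 32  # stand-in key; the real one comes from the ECIES handshake
iv = b"\x00" * 16
aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
enc = aes_cipher.encryptor()
dec = aes_cipher.decryptor()

ciphertext = enc.update(b"sixteen byte msg")
# CTR is a stream mode: the encryptor and decryptor keep independent
# counters, so repeated update() calls stay in sync frame after frame.
assert dec.update(ciphertext) == b"sixteen byte msg"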
Example #2
File: state.py Project: thedadams/py-evm
def __init__(self, state_db: BaseDB, root_hash: bytes,
             peer_pool: PeerPool) -> None:
    self.peer_pool = peer_pool
    self.peer_pool.subscribe(self)
    self.root_hash = root_hash
    self.scheduler = StateSync(root_hash, state_db, self.logger)
    self._running_peers = set()  # type: Set[ETHPeer]
    self.cancel_token = CancelToken('StateDownloader')
Example #3
def __init__(self, chaindb: ChainDB, peer_pool: PeerPool) -> None:
    super(LightChain, self).__init__(chaindb)
    self.peer_pool = peer_pool
    self.peer_pool.subscribe(self)
    self._announcement_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[LESPeer, les.HeadInfo]]  # noqa: E501
    self._last_processed_announcements = {}  # type: Dict[LESPeer, les.HeadInfo]
    self.cancel_token = CancelToken('LightChain')
    self._running_peers = set()  # type: Set[LESPeer]
Example #4
def test_token_chain_trigger_chain():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    token3 = CancelToken('token3')
    chain = token.chain(token2).chain(token3)
    assert not chain.triggered
    chain.trigger()
    assert chain.triggered
    assert chain.triggered_token == chain
    assert not token.triggered
    assert not token2.triggered
    assert not token3.triggered
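Note what this test pins down: trigger() on the chained token marks only that token as triggered, never its sources, while (as Example #7 shows) triggering a source is visible through the chain. A toy token with those semantics, purely to illustrate the contract (the real class is py-evm's CancelToken, whose implementation may differ):

import asyncio

class ToyToken:
    # Illustrative stand-in for CancelToken; not the py-evm implementation.
    def __init__(self, name, sources=()):
        self.name = name
        self._sources = list(sources)  # tokens this one was chained from
        self._event = asyncio.Event()

    def chain(self, other):
        # Return a NEW token; triggering it does not touch self or other.
        return ToyToken('{}:{}'.format(self.name, other.name),
                        self._sources + [self, other])

    def trigger(self):
        self._event.set()

    @property
    def triggered(self):
        return self._event.is_set() or any(t.triggered for t in self._sources)

    @property
    def triggered_token(self):
        if self._event.is_set():
            return self
        for t in self._sources:
            if t.triggered:
                return t.triggered_token
        return None

    async def wait(self):
        # Naive polling loop; good enough to demonstrate the semantics.
        while not self.triggered:
            await asyncio.sleep(0.01)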
Example #5
def __init__(self,
             peer_class: Type[BasePeer],
             chaindb: ChainDB,
             network_id: int,
             privkey: datatypes.PrivateKey,
             ) -> None:
    self.peer_class = peer_class
    self.chaindb = chaindb
    self.network_id = network_id
    self.privkey = privkey
    self.connected_nodes = {}  # type: Dict[Node, BasePeer]
    self.cancel_token = CancelToken('PeerPool')
    self._subscribers = []  # type: List[PeerPoolSubscriber]
Example #6
async def test_wait_cancel_pending_tasks_on_cancellation(event_loop):
    """Test that cancelling a pending CancelToken.wait() coroutine doesn't leave .wait()
    coroutines for any chained tokens behind.
    """
    token = CancelToken('token').chain(CancelToken('token2')).chain(
        CancelToken('token3'))
    token_wait_coroutine = token.wait()
    done, pending = await asyncio.wait([token_wait_coroutine], timeout=0.1)
    assert len(done) == 0
    assert len(pending) == 1
    pending_task = pending.pop()
    assert pending_task._coro == token_wait_coroutine
    pending_task.cancel()
    await assert_only_current_task_not_done()
Example #7
async def test_wait_cancel_pending_tasks_on_completion(event_loop):
    token = CancelToken('token')
    token2 = CancelToken('token2')
    chain = token.chain(token2)
    token2.trigger()
    await chain.wait()
    await assert_only_current_task_not_done()
Example #8
File: state.py Project: thedadams/py-evm
class StateDownloader(PeerPoolSubscriber):
    logger = logging.getLogger("evm.p2p.state.StateDownloader")
    _pending_nodes = {}  # type: Dict[Any, float]
    _total_processed_nodes = 0
    _report_interval = 10  # Number of seconds between progress reports.
    # TODO: Experiment with different timeout/max_pending values to find the combination that
    # yields the best results.
    # FIXME: Should use the # of peers times MAX_STATE_FETCH here
    _max_pending = 5 * MAX_STATE_FETCH
    _reply_timeout = 10  # seconds
    # For simplicity/readability we use 0 here to force a report on the first iteration of the
    # loop.
    _last_report_time = 0

    def __init__(self, state_db: BaseDB, root_hash: bytes,
                 peer_pool: PeerPool) -> None:
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.root_hash = root_hash
        self.scheduler = StateSync(root_hash, state_db, self.logger)
        self._running_peers = set()  # type: Set[ETHPeer]
        self.cancel_token = CancelToken('StateDownloader')

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break
            if isinstance(cmd, eth.NodeData):
                self.logger.debug("Processing NodeData with %d entries",
                                  len(msg))
                for node in msg:
                    self._total_processed_nodes += 1
                    node_key = keccak(node)
                    try:
                        self.scheduler.process([(node_key, node)])
                    except SyncRequestAlreadyProcessed:
                        # This means we received a node more than once, which can happen when we
                        # retry after a timeout.
                        pass
                    # A node may be received more than once, so pop() with a default value.
                    self._pending_nodes.pop(node_key, None)
            else:
                # It'd be very convenient if we could ignore everything that is not NodeData
                # when doing a StateSync, but we need to double-check because peers may consider
                # that "Bad Form" and disconnect from us.
                self.logger.debug("Ignoring %s(%s) while doing a StateSync",
                                  cmd, msg)

    # FIXME: Need a better criterion to select peers here.
    async def get_random_peer(self) -> ETHPeer:
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)
        peer = random.choice(self.peer_pool.peers)
        return cast(ETHPeer, peer)

    async def stop(self):
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        while self._running_peers:
            self.logger.debug("Waiting for %d running peers to finish",
                              len(self._running_peers))
            await asyncio.sleep(0.1)

    async def request_next_batch(self):
        requests = self.scheduler.next_batch(MAX_STATE_FETCH)
        if not requests:
            # Although our run() loop frequently yields control to let our msg handler process
            # received nodes (scheduling new requests), pending nodes can take a while to
            # arrive, temporarily leaving the scheduler with no new requests to hand out.
            self.logger.debug(
                "Scheduler queue is empty, not requesting any nodes")
            return
        self.logger.debug("Requesting %d trie nodes", len(requests))
        await self.request_nodes([request.node_key for request in requests])

    async def request_nodes(self, node_keys: List[bytes]):
        peer = await self.get_random_peer()
        now = time.time()
        for node_key in node_keys:
            self._pending_nodes[node_key] = now
        peer.sub_proto.send_get_node_data(node_keys)

    async def retry_timedout(self):
        timed_out = []
        now = time.time()
        for node_key, req_time in list(self._pending_nodes.items()):
            if now - req_time > self._reply_timeout:
                timed_out.append(node_key)
        if not timed_out:
            return
        self.logger.debug("Re-requesting %d trie nodes", len(timed_out))
        await self.request_nodes(timed_out)

    async def run(self):
        self.logger.info("Starting state sync for root hash %s",
                         encode_hex(self.root_hash))
        while self.scheduler.has_pending_requests and not self.cancel_token.triggered:
            # Request new nodes if we haven't reached the limit of pending nodes.
            if len(self._pending_nodes) < self._max_pending:
                await self.request_next_batch()

            # Retry pending nodes that timed out.
            if self._pending_nodes:
                await self.retry_timedout()

            if len(self._pending_nodes) > self._max_pending:
                # Slow down if we've reached the limit of pending nodes.
                self.logger.debug(
                    "Pending trie nodes limit reached, sleeping a bit")
                await asyncio.sleep(0.3)
            else:
                # Yield control to ensure the Peer's msg_handler callback is called to process any
                # nodes we may have received already. Otherwise we spin too fast and don't process
                # received nodes often enough.
                await asyncio.sleep(0)

            self._maybe_report_progress()

        self.logger.info("Finished state sync with root hash %s",
                         encode_hex(self.root_hash))

    def _maybe_report_progress(self):
        if (time.time() - self._last_report_time) >= self._report_interval:
            self._last_report_time = time.time()
            self.logger.info("Nodes processed: %d",
                             self._total_processed_nodes)
            self.logger.info("Nodes requested but not received yet: %d",
                             len(self._pending_nodes))
            self.logger.info("Nodes scheduled but not requested yet: %d",
                             len(self.scheduler.requests))
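Taken together, the methods above suggest a driver along these lines. This is a hypothetical wiring sketch (sync_state is not a py-evm function), assuming state_db, root_hash and peer_pool are prepared as in the constructor:

import asyncio

async def sync_state(state_db, root_hash, peer_pool):
    downloader = StateDownloader(state_db, root_hash, peer_pool)
    try:
        # run() returns once the scheduler has no pending requests left,
        # or earlier if the cancel token is triggered via stop().
        await downloader.run()
    finally:
        # stop() triggers the token, unsubscribes from the pool and waits
        # for the per-peer _handle_peer() tasks to drain.
        await downloader.stop()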
Example #9
async def test_wait_with_token(event_loop):
    fut = asyncio.Future()
    event_loop.call_soon(functools.partial(fut.set_result, 'result'))
    result = await wait_with_token(fut, CancelToken('token'), timeout=1)
    assert result == 'result'
    await assert_only_current_task_not_done()
Example #10
def test_token_single():
    token = CancelToken('token')
    assert not token.triggered
    token.trigger()
    assert token.triggered
    assert token.triggered_token == token
Example #11
async def test_wait_with_token_operation_cancelled(event_loop):
    token = CancelToken('token')
    token.trigger()
    with pytest.raises(OperationCancelled):
        await wait_with_token(asyncio.sleep(0.02), token)
    await assert_only_current_task_not_done()
Example #12
async def test_wait_with_token_timeout():
    with pytest.raises(TimeoutError):
        await wait_with_token(asyncio.sleep(0.02),
                              CancelToken('token'),
                              timeout=0.01)
    await assert_only_current_task_not_done()
Example #13
async def test_wait_with_token_future_exception(event_loop):
    fut = asyncio.Future()
    event_loop.call_soon(functools.partial(fut.set_exception, Exception()))
    with pytest.raises(Exception):
        await wait_with_token(fut, CancelToken('token'), timeout=1)
    await assert_only_current_task_not_done()
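Examples #9 and #11 through #13 pin down the wait_with_token contract: return the awaitable's result, propagate its exception, raise OperationCancelled once the token fires, raise TimeoutError when the timeout wins, and leave no stray tasks behind. A sketch with those semantics, built only on asyncio primitives (illustrative, not the py-evm implementation):

import asyncio

async def sketch_wait_with_token(awaitable, token, timeout=None):
    # Race the awaitable against the token's wait() coroutine.
    main = asyncio.ensure_future(awaitable)
    token_wait = asyncio.ensure_future(token.wait())
    try:
        done, _ = await asyncio.wait(
            [main, token_wait], timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED)
        if not done:
            raise TimeoutError()  # Example #12: the timeout wins
        if token_wait in done:
            raise OperationCancelled(  # Example #11: the token fired first
                "Cancelled by token: {}".format(token.triggered_token))
        return main.result()  # Examples #9/#13: result, or re-raised exception
    finally:
        # Matches assert_only_current_task_not_done(): no stray tasks survive.
        for task in (main, token_wait):
            if not task.done():
                task.cancel()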
Example #14
class LightChain(Chain, PeerPoolSubscriber):
    logger = logging.getLogger("evm.p2p.lightchain.LightChain")
    max_consecutive_timeouts = 5

    def __init__(self, chaindb: ChainDB, peer_pool: PeerPool) -> None:
        super(LightChain, self).__init__(chaindb)
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self._announcement_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[LESPeer, les.HeadInfo]]  # noqa: E501
        self._last_processed_announcements = {}  # type: Dict[LESPeer, les.HeadInfo]
        self.cancel_token = CancelToken('LightChain')
        self._running_peers = set()  # type: Set[LESPeer]

    @classmethod
    def from_genesis_header(cls, chaindb, genesis_header, peer_pool):
        chaindb.persist_header_to_db(genesis_header)
        return cls(chaindb, peer_pool)

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(LESPeer, peer)))

    async def handle_peer(self, peer: LESPeer) -> None:
        """Handle the lifecycle of the given peer.

        Returns when the peer is finished or when the LightChain is asked to stop.
        """
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: LESPeer) -> None:
        self._announcement_queue.put_nowait((peer, peer.head_info))
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either the peer disconnected or our cancel_token has been triggered, so break
                # out of the loop to stop attempting to sync with this peer.
                break
            # We currently implement only the LES commands for retrieving data (apart from
            # Announce), and those should always come as responses to requests we make, so
            # they will be handled by LESPeer.handle_sub_proto_msg().
            if isinstance(cmd, les.Announce):
                peer.head_info = cmd.as_head_info(msg)
                self._announcement_queue.put_nowait((peer, peer.head_info))
            else:
                raise UnexpectedMessage("Unexpected msg from {}: {}".format(
                    peer, msg))

        await self.drop_peer(peer)
        self.logger.debug("%s finished", peer)

    async def drop_peer(self, peer: LESPeer) -> None:
        self._last_processed_announcements.pop(peer, None)
        await peer.stop()

    async def wait_until_finished(self):
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            self.logger.debug("Waiting for %d running peers to finish",
                              len(self._running_peers))
            await asyncio.sleep(0.1)
        else:
            self.logger.info(
                "Waited too long for peers to finish, exiting anyway")

    async def get_best_peer(self) -> LESPeer:
        """
        Return the peer with the highest announced block height.
        """
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)

        def peer_block_height(peer: LESPeer):
            last_announced = self._last_processed_announcements.get(peer)
            if last_announced is None:
                return -1
            return last_announced.block_number

        # TODO: Should pick a random one in case there are multiple peers with the same block
        # height.
        return max([cast(LESPeer, peer) for peer in self.peer_pool.peers],
                   key=peer_block_height)

    async def wait_for_announcement(self) -> Tuple[LESPeer, les.HeadInfo]:
        """Wait for a new announcement from any of our connected peers.

        Returns a tuple containing the LESPeer on which the announcement was received and the
        announcement info.

        Raises OperationCancelled when LightChain.stop() has been called.
        """
        # Wait for either a new announcement or our cancel_token to be triggered.
        return await wait_with_token(self._announcement_queue.get(),
                                     self.cancel_token)

    async def run(self) -> None:
        """Run the LightChain, ensuring headers are in sync with connected peers.

        If .stop() is called, we'll disconnect from all peers and return.
        """
        self.logger.info("Running LightChain...")
        while True:
            try:
                peer, head_info = await self.wait_for_announcement()
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break

            try:
                await self.process_announcement(peer, head_info)
                self._last_processed_announcements[peer] = head_info
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break
            except LESAnnouncementProcessingError as e:
                self.logger.warning(repr(e))
                await self.drop_peer(peer)
            except Exception as e:
                self.logger.error(
                    "Unexpected error when processing announcement: %s",
                    repr(e))
                await self.drop_peer(peer)

    async def fetch_headers(self, start_block: int,
                            peer: LESPeer) -> List[BlockHeader]:
        for i in range(self.max_consecutive_timeouts):
            try:
                return await peer.fetch_headers_starting_at(
                    start_block, self.cancel_token)
            except asyncio.TimeoutError:
                self.logger.info(
                    "Timeout when fetching headers from %s (attempt %d of %d)",
                    peer, i + 1, self.max_consecutive_timeouts)
                # TODO: Figure out what's a good value to use here.
                await asyncio.sleep(0.5)
        raise TooManyTimeouts()

    async def get_sync_start_block(self, peer: LESPeer,
                                   head_info: les.HeadInfo) -> int:
        chain_head = await self.chaindb.coro_get_canonical_head()
        last_peer_announcement = self._last_processed_announcements.get(peer)
        if chain_head.block_number == GENESIS_BLOCK_NUMBER:
            start_block = GENESIS_BLOCK_NUMBER
        elif last_peer_announcement is None:
            # It's the first time we hear from this peer, need to figure out which headers to
            # get from it.  We can't simply fetch headers starting from our current head
            # number because there may have been a chain reorg, so we fetch some headers prior
            # to our head from the peer, and insert any missing ones in our DB, essentially
            # making our canonical chain identical to the peer's up to
            # chain_head.block_number.
            oldest_ancestor_to_consider = max(
                0, chain_head.block_number - peer.max_headers_fetch + 1)
            try:
                headers = await self.fetch_headers(oldest_ancestor_to_consider,
                                                   peer)
            except EmptyGetBlockHeadersReply:
                raise LESAnnouncementProcessingError(
                    "No common ancestors found between us and {}".format(peer))
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(
                        peer))
            for header in headers:
                await self.chaindb.coro_persist_header_to_db(header)
            start_block = chain_head.block_number
        else:
            start_block = last_peer_announcement.block_number - head_info.reorg_depth
        return start_block

    # TODO: Distribute requests among our peers, ensuring the selected peer has the info we want
    # and respecting the flow control rules.
    async def process_announcement(self, peer: LESPeer,
                                   head_info: les.HeadInfo) -> None:
        if await self.chaindb.coro_header_exists(head_info.block_hash):
            self.logger.debug(
                "Skipping processing of %s from %s as head has already been fetched",
                head_info, peer)
            return

        start_block = await self.get_sync_start_block(peer, head_info)
        while start_block < head_info.block_number:
            try:
                # We should use "start_block + 1" here, but we always re-fetch the last synced
                # block to work around https://github.com/ethereum/go-ethereum/issues/15447
                batch = await self.fetch_headers(start_block, peer)
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(
                        peer))
            for header in batch:
                await self.chaindb.coro_persist_header_to_db(header)
                start_block = header.block_number
            self.logger.info("synced headers up to #%s", start_block)

    async def stop(self):
        self.logger.info("Stopping LightChain...")
        self.cancel_token.trigger()
        self.logger.debug("Waiting for all pending tasks to finish...")
        await self.wait_until_finished()
        self.logger.debug("LightChain finished")

    async def get_canonical_block_by_number(self,
                                            block_number: int) -> BaseBlock:
        """Return the block with the given number from the canonical chain.

        Raises BlockNotFound if it is not found.
        """
        try:
            block_hash = await self.chaindb.coro_lookup_block_hash(block_number)
        except KeyError:
            raise BlockNotFound(
                "No block with number {} found on local chain".format(
                    block_number))
        return await self.get_block_by_hash(block_hash)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_block_by_hash(self, block_hash: bytes) -> BaseBlock:
        peer = await self.get_best_peer()
        try:
            header = await self.chaindb.coro_get_block_header_by_hash(
                block_hash)
        except BlockNotFound:
            self.logger.debug("Fetching header %s from %s",
                              encode_hex(block_hash), peer)
            header = await peer.get_block_header_by_hash(
                block_hash, self.cancel_token)

        self.logger.debug("Fetching block %s from %s", encode_hex(block_hash),
                          peer)
        body = await peer.get_block_by_hash(block_hash, self.cancel_token)
        block_class = self.get_vm_class_for_block_number(
            header.block_number).get_block_class()
        transactions = [
            block_class.transaction_class.from_base_transaction(tx)
            for tx in body.transactions
        ]
        return block_class(
            header=header,
            transactions=transactions,
            uncles=body.uncles,
        )

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_receipts(self, block_hash: bytes) -> List[Receipt]:
        peer = await self.get_best_peer()
        self.logger.debug("Fetching %s receipts from %s",
                          encode_hex(block_hash), peer)
        return await peer.get_receipts(block_hash, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_account(self, block_hash: bytes, address: bytes) -> Account:
        peer = await self.get_best_peer()
        return await peer.get_account(block_hash, address, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_contract_code(self, block_hash: bytes, key: bytes) -> bytes:
        peer = await self.get_best_peer()
        return await peer.get_contract_code(block_hash, key, self.cancel_token)
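A plausible entry point tying LightChain to its PeerPool. The event-loop scaffolding below is a sketch assuming both objects are constructed as in the examples; it is not py-evm's actual runner:

import asyncio

def run_light_chain(chaindb, peer_pool):
    chain = LightChain(chaindb, peer_pool)
    loop = asyncio.get_event_loop()
    # The pool dials peers in the background while the chain syncs headers.
    asyncio.ensure_future(peer_pool.run())
    try:
        loop.run_until_complete(chain.run())
    except KeyboardInterrupt:
        pass
    finally:
        loop.run_until_complete(chain.stop())
        loop.run_until_complete(peer_pool.stop())
        loop.close()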
Example #15
class PeerPool:
    """PeerPool attempts to keep connections to at least min_peers on the given network."""
    logger = logging.getLogger("evm.p2p.peer.PeerPool")
    min_peers = 2
    _connect_loop_sleep = 2

    def __init__(self,
                 peer_class: Type[BasePeer],
                 chaindb: ChainDB,
                 network_id: int,
                 privkey: datatypes.PrivateKey,
                 ) -> None:
        self.peer_class = peer_class
        self.chaindb = chaindb
        self.network_id = network_id
        self.privkey = privkey
        self.connected_nodes = {}  # type: Dict[Node, BasePeer]
        self.cancel_token = CancelToken('PeerPool')
        self._subscribers = []  # type: List[PeerPoolSubscriber]

    def subscribe(self, subscriber: PeerPoolSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)

    def unsubscribe(self, subscriber: PeerPoolSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)

    async def get_nodes_to_connect(self) -> List[Node]:
        # TODO: This should use the Discovery service to lookup nodes to connect to, but our
        # current implementation only supports v4 and with that it takes an insane amount of time
        # to find any LES nodes with the same network ID as us, so for now we hard-code some nodes
        # that seem to have a good uptime.
        from evm.chains.ropsten import RopstenChain
        from evm.chains.mainnet import MainnetChain
        if self.network_id == MainnetChain.network_id:
            return [
                Node(
                    keys.PublicKey(decode_hex("1118980bf48b0a3640bdba04e0fe78b1add18e1cd99bf22d53daac1fd9972ad650df52176e7c7d89d1114cfef2bc23a2959aa54998a46afcf7d91809f0855082")),  # noqa: E501
                    Address("52.74.57.123", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("78de8a0916848093c73790ead81d1928bec737d565119932b98c6b100d944b7a95e94f847f689fc723399d2e31129d182f7ef3863f2b4c820abbf3ab2722344d")),  # noqa: E501
                    Address("191.235.84.50", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("ddd81193df80128880232fc1deb45f72746019839589eeb642d3d44efbb8b2dda2c1a46a348349964a6066f8afb016eb2a8c0f3c66f32fadf4370a236a4b5286")),  # noqa: E501
                    Address("52.231.202.145", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("3f1d12044546b76342d59d4a05532c14b85aa669704bfe1f864fe079415aa2c02d743e03218e57a33fb94523adb54032871a6c51b2cc5514cb7c7e35b3ed0a99")),  # noqa: E501
                    Address("13.93.211.84", 30303, 30303)),
            ]
        elif self.network_id == RopstenChain.network_id:
            return [
                Node(
                    keys.PublicKey(decode_hex("88c2b24429a6f7683fbfd06874ae3f1e7c8b4a5ffb846e77c705ba02e2543789d66fc032b6606a8d8888eb6239a2abe5897ce83f78dcdcfcb027d6ea69aa6fe9")),  # noqa: E501
                    Address("163.172.157.61", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("a1ef9ba5550d5fac27f7cbd4e8d20a643ad75596f307c91cd6e7f85b548b8a6bf215cca436d6ee436d6135f9fe51398f8dd4c0bd6c6a0c332ccb41880f33ec12")),  # noqa: E501
                    Address("51.15.218.125", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("e80276aabb7682a4a659f4341c1199de79d91a2e500a6ee9bed16ed4ce927ba8d32ba5dea357739ffdf2c5bcc848d3064bb6f149f0b4249c1f7e53f8bf02bfc8")),  # noqa: E501
                    Address("51.15.39.57", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("584c0db89b00719e9e7b1b5c32a4a8942f379f4d5d66bb69f9c7fa97fa42f64974e7b057b35eb5a63fd7973af063f9a1d32d8c60dbb4854c64cb8ab385470258")),  # noqa: E501
                    Address("51.15.35.2", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("d40871fc3e11b2649700978e06acd68a24af54e603d4333faecb70926ca7df93baa0b7bf4e927fcad9a7c1c07f9b325b22f6d1730e728314d0e4e6523e5cebc2")),  # noqa: E501
                    Address("51.15.132.235", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("482484b9198530ee2e00db89791823244ca41dcd372242e2e1297dd06f6d8dd357603960c5ad9cc8dc15fcdf0e4edd06b7ad7db590e67a0b54f798c26581ebd7")),  # noqa: E501
                    Address("51.15.75.138", 30303, 30303)),
            ]
        else:
            raise ValueError("Unknown network_id: {}".format(self.network_id))

    async def run(self):
        self.logger.info("Running PeerPool...")
        while not self.cancel_token.triggered:
            try:
                await self.maybe_connect_to_more_peers()
            except:  # noqa: E722
                # Most unexpected errors should be transient, so we log and restart from scratch.
                self.logger.error("Unexpected error (%s), restarting", traceback.format_exc())
                await self.stop_all_peers()
            # Wait self._connect_loop_sleep seconds, unless we're asked to stop.
            await asyncio.wait([self.cancel_token.wait()], timeout=self._connect_loop_sleep)

    async def stop_all_peers(self):
        self.logger.info("Stopping all peers ...")
        await asyncio.gather(
            *[peer.stop() for peer in self.connected_nodes.values()])

    async def stop(self):
        self.cancel_token.trigger()
        await self.stop_all_peers()

    async def connect(self, remote: Node) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Returns None if the remote is unreachable, times out or is useless.
        """
        if remote in self.connected_nodes:
            self.logger.debug("Skipping %s; already connected to it", remote)
            return None
        expected_exceptions = (
            UnreachablePeer, asyncio.TimeoutError, PeerConnectionLost, HandshakeFailure)
        try:
            self.logger.info("Connecting to %s...", remote)
            # TODO: Use asyncio.wait() and our cancel_token here to cancel in case the token is
            # triggered.
            peer = await asyncio.wait_for(
                handshake(remote, self.privkey, self.peer_class, self.chaindb, self.network_id),
                HANDSHAKE_TIMEOUT)
            return peer
        except expected_exceptions as e:
            self.logger.info("Could not complete handshake with %s: %s", remote, repr(e))
        except Exception:
            self.logger.warning("Unexpected error during auth/p2p handshake with %s: %s",
                                remote, traceback.format_exc())
        return None

    async def maybe_connect_to_more_peers(self):
        """Connect to more peers if we're not yet connected to at least self.min_peers."""
        if len(self.connected_nodes) >= self.min_peers:
            self.logger.debug(
                "Already connected to %s peers: %s; sleeping",
                len(self.connected_nodes),
                [remote for remote in self.connected_nodes])
            return

        for node in await self.get_nodes_to_connect():
            # TODO: Consider changing connect() to raise an exception instead of returning None,
            # as discussed in
            # https://github.com/pipermerriam/py-evm/pull/139#discussion_r152067425
            peer = await self.connect(node)
            if peer is not None:
                self.logger.info("Successfully connected to %s", peer)
                asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
                self.connected_nodes[peer.remote] = peer
                for subscriber in self._subscribers:
                    subscriber.register_peer(peer)

    def _peer_finished(self, peer: BasePeer) -> None:
        """Remove the given peer from our list of connected nodes.
        This is passed as a callback to be called when a peer finishes.
        """
        if peer.remote in self.connected_nodes:
            self.connected_nodes.pop(peer.remote)

    @property
    def peers(self) -> List[BasePeer]:
        return list(self.connected_nodes.values())
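subscribe()/register_peer() is the pool's extension point, and StateDownloader and LightChain both use it the same way: subscribe, then spawn a per-peer handler task from register_peer(). A minimal subscriber along those lines, assuming PeerPoolSubscriber requires nothing beyond register_peer():

class CountingSubscriber(PeerPoolSubscriber):
    # Toy subscriber: counts the peers the pool hands to it.
    def __init__(self):
        self.seen = 0

    def register_peer(self, peer: BasePeer) -> None:
        # Called once per already-connected peer at subscribe() time,
        # then again for every peer that connects afterwards.
        self.seen += 1

# Hypothetical wiring, given a constructed pool:
#   pool = PeerPool(LESPeer, chaindb, network_id, privkey)
#   pool.subscribe(CountingSubscriber())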
Example #16
class BasePeer:
    logger = logging.getLogger("evm.p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: PreImage,
                 ingress_mac: PreImage,
                 chaindb: ChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    def send_sub_proto_handshake(self):
        raise NotImplementedError()

    def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        raise NotImplementedError()

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        Raises OperationCancelled if the peer has been disconnected.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        return await wait_with_token(self.sub_proto_msg_queue.get(), combined_token)

    @property
    def genesis(self) -> BlockHeader:
        genesis_hash = self.chaindb.lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return self.chaindb.get_block_header_by_hash(genesis_hash)

    @property
    def _local_chain_info(self) -> 'ChainInfo':
        genesis = self.genesis
        head = self.chaindb.get_canonical_head()
        total_difficulty = self.chaindb.get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[bytes, int]]:
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg."""
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length:
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), self.cancel_token, timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError):
            raise PeerConnectionLost("EOF reading from stream")

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.error(
                "Unexpected error when handling remote msg: %s", traceback.format_exc())
        finally:
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        self.close()
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, asyncio.TimeoutError) as e:
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return

            self.process_msg(cmd, msg)

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to a 16-byte
        # boundary, so we round up here to ensure we read all of the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        decoded_msg = cmd.decode(msg)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType):
        """Handle the base protocol (P2P) messages."""
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            self.logger.debug(
                "%s disconnected; reason given: %s", self, msg['reason_name'])
            self.close()
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType):
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd.proto, P2PProtocol):
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        self.sub_proto = self.select_sub_protocol(remote_capabilities)
        if self.sub_proto is None:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; Got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size

    def send(self, header: bytes, body: bytes) -> None:
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.
        """
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> protocol.Protocol:
        """Select the sub-protocol to use when talking to the remote.

        Find the highest version of our supported sub-protocols that is also supported by the
        remote and stores an instance of it (with the appropriate cmd_id offset) in
        self.sub_proto.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        return None

    def __str__(self):
        return "{} {}".format(self.__class__.__name__, self.remote)