Example #1
async def test_wait_cancel_pending_tasks_on_completion(event_loop):
    token = CancelToken('token')
    token2 = CancelToken('token2')
    chain = token.chain(token2)
    token2.trigger()
    await chain.wait()
    await assert_only_current_task_not_done()
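
The test above exercises three properties of CancelToken: chain() returns a new token, triggering any token in the chain marks the chained token as triggered, and wait() returns once that happens. Below is a minimal sketch of those semantics; it is an illustration, not the real p2p.cancel_token implementation (which waits on futures rather than polling), and every name in it is made up.

import asyncio

class SketchCancelToken:
    """Illustrative stand-in for CancelToken; the behavior here is an assumption."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._chained = []  # type: list
        self._triggered = asyncio.Event()

    def chain(self, other: 'SketchCancelToken') -> 'SketchCancelToken':
        # Chaining returns a *new* token that counts as triggered as soon as
        # either parent (or the chained token itself) is triggered.
        chained = SketchCancelToken('{}:{}'.format(self.name, other.name))
        chained._chained = [self, other]
        return chained

    def trigger(self) -> None:
        self._triggered.set()

    @property
    def triggered_token(self) -> 'SketchCancelToken':
        if self._triggered.is_set():
            return self
        for token in self._chained:
            if token.triggered:
                return token.triggered_token
        return None

    @property
    def triggered(self) -> bool:
        return self.triggered_token is not None

    async def wait(self) -> None:
        # Polling keeps the sketch short; the real token awaits futures instead.
        while not self.triggered:
            await asyncio.sleep(0.01)
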
Example #2
    def __init__(self, token: CancelToken = None) -> None:
        self.finished = asyncio.Event()

        base_token = CancelToken(type(self).__name__)
        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)
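
A hypothetical caller of the constructor pattern above (SomeService and the token names are made up): because the service's base token is chained with the caller's, triggering the caller's token also cancels the service.

parent_token = CancelToken('app')
service = SomeService(token=parent_token)  # SomeService uses the __init__ above
assert not service.cancel_token.triggered
parent_token.trigger()
assert service.cancel_token.triggered
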
Example #3
def test_token_chain_trigger_last():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    token3 = CancelToken('token3')
    chain = token.chain(token2).chain(token3)
    assert not chain.triggered
    token3.trigger()
    assert chain.triggered
    assert chain.triggered_token == token3
Example #4
    def __init__(self, token: CancelToken = None) -> None:
        if self.logger is None:
            self.logger = logging.getLogger(self.__module__ + '.' + self.__class__.__name__)

        self._run_lock = asyncio.Lock()
        self.cleaned_up = asyncio.Event()

        base_token = CancelToken(type(self).__name__)
        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)
Example #5
 def __init__(self,
              account_db: BaseDB,
              root_hash: bytes,
              peer_pool: PeerPool,
              token: CancelToken = None) -> None:
     cancel_token = CancelToken('StateDownloader')
     if token is not None:
         cancel_token = cancel_token.chain(token)
     super().__init__(cancel_token)
     self.peer_pool = peer_pool
     self.root_hash = root_hash
     self.scheduler = StateSync(root_hash, account_db)
     self._running_peers = set()  # type: Set[ETHPeer]
     self._peers_with_pending_requests = {}  # type: Dict[ETHPeer, float]
Example #6
 def __init__(self,
              chain: AsyncChain,
              chaindb: AsyncChainDB,
              db: BaseDB,
              peer_pool: PeerPool,
              token: CancelToken = None) -> None:
     cancel_token = CancelToken('FullNodeSyncer')
     if token is not None:
         cancel_token = cancel_token.chain(token)
     super().__init__(cancel_token)
     self.chain = chain
     self.chaindb = chaindb
     self.db = db
     self.peer_pool = peer_pool
Example #7
    def __init__(self, shard: Shard, peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        cancel_token = CancelToken("ShardSyncer")
        if token is not None:
            cancel_token = cancel_token.chain(token)
        super().__init__(cancel_token)

        self.shard = shard
        self.peer_pool = peer_pool

        self.incoming_collation_queue: asyncio.Queue[Collation] = asyncio.Queue()

        self.collations_received_event = asyncio.Event()

        self.start_time = time.time()
Example #8
    async def wait_first(
            self, *futures: Awaitable, token: CancelToken = None, timeout: float = None) -> Any:
        """Wait for the first future to complete, unless we timeout or the token chain is triggered.

        The given token is chained with this service's token, so triggering either will cancel
        this.

        Returns the result of the first future to complete.

        Raises TimeoutError if we timeout or OperationCancelled if the token chain is triggered.

        All pending futures are cancelled before returning.
        """
        if token is None:
            token_chain = self.cancel_token
        else:
            token_chain = token.chain(self.cancel_token)
        return await wait_with_token(*futures, token=token_chain, timeout=timeout)
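
The wait_with_token() helper that wait_first() delegates to is not shown in this example. The sketch below implements behavior consistent with the docstring above; it is an assumption about the real helper, and OperationCancelled here is a local stand-in for p2p's exception of the same name.

import asyncio
from typing import Any, Awaitable


class OperationCancelled(Exception):
    pass  # local stand-in for p2p.exceptions.OperationCancelled


async def wait_with_token_sketch(*futures: Awaitable, token, timeout: float = None) -> Any:
    token_future = asyncio.ensure_future(token.wait())
    tasks = [asyncio.ensure_future(f) for f in futures]
    done, pending = await asyncio.wait(
        tasks + [token_future], timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
    # All pending futures are cancelled before returning, as documented above.
    for task in pending:
        task.cancel()
    if not done:
        raise TimeoutError()
    if token_future in done:
        raise OperationCancelled('Token triggered: {}'.format(token))
    return done.pop().result()
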
Example #9
 def __init__(self,
              chaindb: AsyncChainDB,
              peer_pool: PeerPool,
              token: CancelToken = None) -> None:
     cancel_token = CancelToken('ChainSyncer')
     if token is not None:
         cancel_token = cancel_token.chain(token)
     super().__init__(cancel_token)
     self.chaindb = chaindb
     self.peer_pool = peer_pool
     self._running_peers = set()  # type: Set[ETHPeer]
     self._syncing = False
     self._sync_complete = asyncio.Event()
     self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
     self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
     # Those are used by our msg handlers and _download_block_parts() in order to track missing
     # bodies/receipts for a given chain segment.
     self._downloaded_receipts = asyncio.Queue()  # type: asyncio.Queue[List[DownloadedBlockPart]]  # noqa: E501
     self._downloaded_bodies = asyncio.Queue()  # type: asyncio.Queue[List[DownloadedBlockPart]]
Example #10
    def __init__(self,
                 token: CancelToken = None,
                 loop: asyncio.AbstractEventLoop = None) -> None:
        if self.logger is None:
            self.logger = cast(
                TraceLogger,
                logging.getLogger(self.__module__ + '.' +
                                  self.__class__.__name__))

        self._run_lock = asyncio.Lock()
        self.cleaned_up = asyncio.Event()
        self._child_services = []

        self.loop = loop
        base_token = CancelToken(type(self).__name__, loop=loop)

        if token is None:
            self.cancel_token = base_token
        else:
            self.cancel_token = base_token.chain(token)
Example #11
class BasePeer:
    logger = logging.getLogger("p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: PreImage,
                 ingress_mac: PreImage,
                 chaindb: AsyncChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    async def send_sub_proto_handshake(self):
        raise NotImplementedError()

    async def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        raise NotImplementedError()

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        await self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        await self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        Raises OperationCancelled if the peer has been disconnected.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        return await wait_with_token(self.sub_proto_msg_queue.get(), combined_token)

    @property
    async def genesis(self) -> BlockHeader:
        genesis_hash = await self.chaindb.coro_lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return await self.chaindb.coro_get_block_header_by_hash(genesis_hash)

    @property
    async def _local_chain_info(self) -> 'ChainInfo':
        genesis = await self.genesis
        head = await self.chaindb.coro_get_canonical_head()
        total_difficulty = await self.chaindb.coro_get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[bytes, int]]:
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg."""
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length:
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), self.cancel_token, timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError):
            raise PeerConnectionLost("EOF reading from stream")

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.error(
                "Unexpected error when handling remote msg: %s", traceback.format_exc())
        finally:
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        self.close()
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, asyncio.TimeoutError) as e:
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return

            self.process_msg(cmd, msg)

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to the 16-byte
        # boundary, so we need to round up here to ensure we read all of the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        decoded_msg = cmd.decode(msg)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Handle the base protocol (P2P) messages."""
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            self.logger.debug(
                "%s disconnected; reason given: %s", self, msg['reason_name'])
            self.close()
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd.proto, P2PProtocol):
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        self.sub_proto = self.select_sub_protocol(remote_capabilities)
        if self.sub_proto is None:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; Got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size

    def send(self, header: bytes, body: bytes) -> None:
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.
        """
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> protocol.Protocol:
        """Select the sub-protocol to use when talking to the remote.

        Find the highest version of our supported sub-protocols that is also supported by the
        remote and return an instance of it (with the appropriate cmd_id offset), or None when
        there is no match.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        if not matching_capabilities:
            return None
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        return None

    def __str__(self):
        return "{} {}".format(self.__class__.__name__, self.remote)
Example #12
def test_token_chain_event_loop_mismatch():
    token = CancelToken('token')
    token2 = CancelToken('token2', loop=asyncio.new_event_loop())
    with pytest.raises(EventLoopMismatch):
        token.chain(token2)
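
The test above implies that chain() refuses to combine tokens created on different event loops. A minimal illustration of such a guard, under the assumption that each token records the loop it was created with (EventLoopMismatch here is a local stand-in for the imported exception):

import asyncio


class EventLoopMismatch(Exception):
    pass  # local stand-in for the exception used in the test above


class LoopAwareToken:
    def __init__(self, name: str, loop: asyncio.AbstractEventLoop = None) -> None:
        self.name = name
        self.loop = loop if loop is not None else asyncio.get_event_loop()

    def chain(self, other: 'LoopAwareToken') -> 'LoopAwareToken':
        if self.loop != other.loop:
            raise EventLoopMismatch('Chained tokens must share an event loop')
        return LoopAwareToken('{}:{}'.format(self.name, other.name), loop=self.loop)
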
Example #13
class ChainSyncer(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.chain.ChainSyncer")
    # We'll only sync if we are connected to at least min_peers_to_sync.
    min_peers_to_sync = 2
    # TODO: Instead of a fixed timeout, we should use a variable one that gets adjusted based on
    # the round-trip times from our download requests.
    _reply_timeout = 60

    def __init__(self,
                 chaindb: AsyncChainDB,
                 peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        self.chaindb = chaindb
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.cancel_token = CancelToken('ChainSyncer')
        if token is not None:
            self.cancel_token = self.cancel_token.chain(token)
        self._running_peers = set()  # type: Set[ETHPeer]
        self._syncing = False
        self._sync_complete = asyncio.Event()
        self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
        self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        self._downloaded_receipts = asyncio.Queue()  # type: asyncio.Queue[List[bytes]]
        self._downloaded_bodies = asyncio.Queue()  # type: asyncio.Queue[List[Tuple[bytes, bytes]]]

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))
        highest_td_peer = max(
            [cast(ETHPeer, peer) for peer in self.peer_pool.peers],
            key=operator.attrgetter('head_td'))
        self._sync_requests.put_nowait(highest_td_peer)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            pending_msgs = peer.sub_proto_msg_queue.qsize()
            if pending_msgs:
                self.logger.debug(
                    "Read %s msg from %s's queue; %d msgs pending", cmd, peer,
                    pending_msgs)

            # Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main
            # loop can keep processing msgs, which is why we use ensure_future() instead of
            # awaiting it here.
            asyncio.ensure_future(self.handle_msg(peer, cmd, msg))

    async def handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                         msg: protocol._DecodedMsgType) -> None:
        try:
            await self._handle_msg(peer, cmd, msg)
        except OperationCancelled:
            # Silently swallow OperationCancelled exceptions because we run unsupervised (i.e.
            # with ensure_future()). Our caller will also get an OperationCancelled anyway, and
            # there it will be handled.
            pass
        except Exception:
            self.logger.exception(
                "Unexpected error when processing msg from %s", peer)

    async def run(self) -> None:
        while True:
            peer_or_finished = await wait_with_token(
                self._sync_requests.get(),
                self._sync_complete.wait(),
                token=self.cancel_token)

            if self._sync_complete.is_set():
                return

            # Since self._sync_complete is not set, peer_or_finished can only be an ETHPeer
            # instance.
            asyncio.ensure_future(self.sync(peer_or_finished))

    async def sync(self, peer: ETHPeer) -> None:
        if self._syncing:
            self.logger.debug(
                "Got a NewBlock or a new peer, but already syncing so doing nothing"
            )
            return
        elif len(self._running_peers) < self.min_peers_to_sync:
            self.logger.warn(
                "Connected to less peers (%d) than the minimum (%d) required to sync, "
                "doing nothing", len(self._running_peers),
                self.min_peers_to_sync)
            return

        self._syncing = True
        try:
            await self._sync(peer)
        except OperationCancelled:
            pass
        finally:
            self._syncing = False

    async def _sync(self, peer: ETHPeer) -> None:
        head = await self.chaindb.coro_get_canonical_head()
        head_td = await self.chaindb.coro_get_score(head.hash)
        if peer.head_td <= head_td:
            self.logger.info(
                "Head TD (%d) announced by %s not higher than ours (%d), not syncing",
                peer.head_td, peer, head_td)
            return

        self.logger.info("Starting sync with %s", peer)
        # FIXME: Fetch a batch of headers, in reverse order, starting from our current head, and
        # find the common ancestor between our chain and the peer's.
        start_at = max(0, head.block_number - eth.MAX_HEADERS_FETCH)
        while True:
            self.logger.info("Fetching chain segment starting at #%d",
                             start_at)
            peer.sub_proto.send_get_block_headers(start_at,
                                                  eth.MAX_HEADERS_FETCH,
                                                  reverse=False)
            try:
                headers = await wait_with_token(self._new_headers.get(),
                                                peer.wait_until_finished(),
                                                token=self.cancel_token,
                                                timeout=self._reply_timeout)
            except TimeoutError:
                self.logger.warn(
                    "Timeout waiting for header batch from %s, aborting sync",
                    peer)
                await peer.stop()
                break

            if peer.is_finished():
                self.logger.info("%s disconnected, aborting sync", peer)
                break

            self.logger.info("Got headers segment starting at #%d", start_at)

            # TODO: Process headers for consistency.

            await self._download_block_parts(
                [header for header in headers if not _is_body_empty(header)],
                self.request_bodies, self._downloaded_bodies, _body_key,
                'body')

            self.logger.info(
                "Got block bodies for chain segment starting at #%d", start_at)

            missing_receipts = [
                header for header in headers if not _is_receipts_empty(header)
            ]
            # Post-Byzantium blocks may have identical receipt roots (e.g. when they have the same
            # number of transactions and they all succeed/fail: ropsten blocks 2503212 and 2503284),
            # so we do this to avoid requesting the same receipts multiple times.
            missing_receipts = list(unique(missing_receipts,
                                           key=_receipts_key))
            await self._download_block_parts(missing_receipts,
                                             self.request_receipts,
                                             self._downloaded_receipts,
                                             _receipts_key, 'receipt')

            self.logger.info(
                "Got block receipts for chain segment starting at #%d",
                start_at)

            for header in headers:
                await self.chaindb.coro_persist_header(header)
                start_at = header.block_number + 1

            self.logger.info("Imported chain segment, new head: #%d",
                             start_at - 1)
            head = await self.chaindb.coro_get_canonical_head()
            if head.hash == peer.head_hash:
                self.logger.info("Chain sync with %s completed", peer)
                self._sync_complete.set()
                break

    async def _download_block_parts(
            self,
            headers: List[BlockHeader],
            request_func: Callable[[List[BlockHeader]], int],
            download_queue: 'Union[asyncio.Queue[List[bytes]], asyncio.Queue[List[Tuple[bytes, bytes]]]]',  # noqa: E501
            key_func: Callable[[BlockHeader], Union[bytes, Tuple[bytes, bytes]]],
            part_name: str) -> None:
        missing = headers.copy()
        # The ETH protocol doesn't guarantee that we'll get all body parts requested, so we need
        # to keep track of the number of pending replies and missing items to decide when to retry
        # them. See request_receipts() for more info.
        pending_replies = request_func(missing)
        while missing:
            if pending_replies == 0:
                pending_replies = request_func(missing)

            try:
                downloaded = await wait_with_token(download_queue.get(),
                                                   token=self.cancel_token,
                                                   timeout=self._reply_timeout)
            except TimeoutError:
                pending_replies = request_func(missing)
                continue

            pending_replies -= 1
            unexpected = set(downloaded).difference(
                [key_func(header) for header in missing])
            for item in unexpected:
                self.logger.warn("Got unexpected %s: %s", part_name,
                                 unexpected)
            missing = [
                header for header in missing
                if key_func(header) not in downloaded
            ]

    def _request_block_parts(
            self, headers: List[BlockHeader],
            request_func: Callable[[ETHPeer, List[BlockHeader]], None]) -> int:
        length = math.ceil(len(headers) / len(self.peer_pool.peers))
        batches = list(partition_all(length, headers))
        for peer, batch in zip(self.peer_pool.peers, batches):
            request_func(cast(ETHPeer, peer), batch)
        return len(batches)

    def _send_get_block_bodies(self, peer: ETHPeer,
                               headers: List[BlockHeader]) -> None:
        self.logger.debug("Requesting %d block bodies to %s", len(headers),
                          peer)
        peer.sub_proto.send_get_block_bodies(
            [header.hash for header in headers])

    def _send_get_receipts(self, peer: ETHPeer,
                           headers: List[BlockHeader]) -> None:
        self.logger.debug("Requesting %d block receipts to %s", len(headers),
                          peer)
        peer.sub_proto.send_get_receipts([header.hash for header in headers])

    def request_bodies(self, headers: List[BlockHeader]) -> int:
        """Ask our peers for bodies for the given headers.

        See request_receipts() for details of how this is done.
        """
        return self._request_block_parts(headers, self._send_get_block_bodies)

    def request_receipts(self, headers: List[BlockHeader]) -> int:
        """Ask our peers for receipts for the given headers.

        We partition the given list of headers into batches and request each batch from one of
        our connected peers. This is done because geth enforces a byte-size cap when replying
        to a GetReceipts msg, and we then need to re-request the items that didn't fit, so by
        splitting the requests across all our peers we reduce the likelihood of having to make
        multiple serialized requests for the missing items (which happens quite frequently in
        practice).

        Returns the number of requests made.
        """
        return self._request_block_parts(headers, self._send_get_receipts)

    async def wait_until_finished(self) -> None:
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        self.logger.info("Waiting for %d running peers to finish",
                         len(self._running_peers))
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            await asyncio.sleep(0.1)
        else:
            self.logger.info(
                "Waited too long for peers to finish, exiting anyway")

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        await self.wait_until_finished()

    async def _handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                          msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd, eth.BlockHeaders):
            msg = cast(List[BlockHeader], msg)
            self.logger.debug("Got BlockHeaders from %d to %d",
                              msg[0].block_number, msg[-1].block_number)
            self._new_headers.put_nowait(msg)
        elif isinstance(cmd, eth.BlockBodies):
            await self._handle_block_bodies(peer, cast(List[eth.BlockBody],
                                                       msg))
        elif isinstance(cmd, eth.Receipts):
            await self._handle_block_receipts(
                peer, cast(List[List[eth.Receipt]], msg))
        elif isinstance(cmd, eth.NewBlock):
            msg = cast(Dict[str, Any], msg)
            header = msg['block'][0]
            actual_head = header.parent_hash
            actual_td = msg['total_difficulty'] - header.difficulty
            if actual_td > peer.head_td:
                peer.head_hash = actual_head
                peer.head_td = actual_td
                self._sync_requests.put_nowait(peer)
        elif isinstance(cmd, eth.Transactions):
            # TODO: Figure out what to do with those.
            pass
        elif isinstance(cmd, eth.NewBlockHashes):
            # TODO: Figure out what to do with those.
            pass
        else:
            # TODO: There are other msg types we'll want to handle here, but for now just log them
            # as a warning so we don't forget about it.
            self.logger.warn("Got unexpected msg: %s (%s)", cmd, msg)

    async def _handle_block_receipts(
            self, peer: ETHPeer, receipts: List[List[eth.Receipt]]) -> None:
        self.logger.debug("Got Receipts for %d blocks from %s", len(receipts),
                          peer)
        loop = asyncio.get_event_loop()
        iterator = map(make_trie_root_and_nodes, receipts)
        receipts_tries = await wait_with_token(loop.run_in_executor(
            None, list, iterator),
                                               token=self.cancel_token)
        for receipt_root, trie_dict_data in receipts_tries:
            await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
        receipt_roots = [receipt_root for receipt_root, _ in receipts_tries]
        self._downloaded_receipts.put_nowait(receipt_roots)

    async def _handle_block_bodies(self, peer: ETHPeer,
                                   bodies: List[eth.BlockBody]) -> None:
        self.logger.debug("Got Bodies for %d blocks from %s", len(bodies),
                          peer)
        loop = asyncio.get_event_loop()
        iterator = map(make_trie_root_and_nodes,
                       [body.transactions for body in bodies])
        transactions_tries = await wait_with_token(loop.run_in_executor(
            None, list, iterator),
                                                   token=self.cancel_token)
        body_keys = []
        for (body, (tx_root, trie_dict_data)) in zip(bodies,
                                                     transactions_tries):
            await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
            uncles_hash = await self.chaindb.coro_persist_uncles(body.uncles)
            body_keys.append((tx_root, uncles_hash))
        self._downloaded_bodies.put_nowait(body_keys)
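
A worked example of the batching done by _request_block_parts() above: with 10 headers and 3 peers, each batch holds ceil(10 / 3) == 4 headers, so 3 requests are made and the function returns 3 as the number of pending replies.

import math
from cytoolz import partition_all  # the plain toolz package behaves the same

headers = list(range(10))  # stand-ins for BlockHeader objects
num_peers = 3
length = math.ceil(len(headers) / num_peers)
batches = list(partition_all(length, headers))
assert length == 4
assert [len(batch) for batch in batches] == [4, 4, 2]
assert len(batches) == 3
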
Example #14
class DiscoveryProtocol(asyncio.DatagramProtocol):
    """A Kademlia-like protocol to discover RLPx nodes."""
    logger = logging.getLogger("p2p.discovery.DiscoveryProtocol")
    transport: asyncio.DatagramTransport = None
    _max_neighbours_per_packet_cache = None

    def __init__(self, privkey: datatypes.PrivateKey, address: kademlia.Address,
                 bootstrap_nodes: List[str]) -> None:
        self.privkey = privkey
        self.address = address
        self.bootstrap_nodes = [kademlia.Node.from_uri(node) for node in bootstrap_nodes]
        self.this_node = kademlia.Node(self.pubkey, address)
        self.kademlia = kademlia.KademliaProtocol(self.this_node, wire=self)
        self.cancel_token = CancelToken('DiscoveryProtocol')

    async def lookup_random(self, cancel_token: CancelToken) -> List[kademlia.Node]:
        return await self.kademlia.lookup_random(self.cancel_token.chain(cancel_token))

    def get_random_nodes(self, count: int) -> Generator[kademlia.Node, None, None]:
        return self.kademlia.routing.get_random_nodes(count)

    @property
    def pubkey(self) -> datatypes.PublicKey:
        return self.privkey.public_key

    def _get_handler(self, cmd: Command) -> Callable[[kademlia.Node, List[Any], bytes], None]:
        if cmd == CMD_PING:
            return self.recv_ping
        elif cmd == CMD_PONG:
            return self.recv_pong
        elif cmd == CMD_FIND_NODE:
            return self.recv_find_node
        elif cmd == CMD_NEIGHBOURS:
            return self.recv_neighbours
        else:
            raise ValueError("Unknwon command: {}".format(cmd))

    def _get_max_neighbours_per_packet(self):
        if self._max_neighbours_per_packet_cache is not None:
            return self._max_neighbours_per_packet_cache
        self._max_neighbours_per_packet_cache = _get_max_neighbours_per_packet()
        return self._max_neighbours_per_packet_cache

    def connection_made(self, transport):
        self.transport = transport

    async def bootstrap(self):
        while self.transport is None:
            # FIXME: Instead of sleeping here to wait until connection_made() is called to set
            # .transport we should instead only call it after we know it's been set.
            await asyncio.sleep(1)
        self.logger.debug("boostrapping with %s", self.bootstrap_nodes)
        await self.kademlia.bootstrap(self.bootstrap_nodes, self.cancel_token)

    # FIXME: Enable type checking here once we have a mypy version that
    # includes the fix for https://github.com/python/typeshed/pull/1740
    def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:  # type: ignore
        ip_address, udp_port = addr
        # XXX: For now we simply discard all v5 messages. The prefix below is what geth uses to
        # identify them:
        # https://github.com/ethereum/go-ethereum/blob/c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5/p2p/discv5/udp.go#L149  # noqa: E501
        if text_if_str(to_bytes, data).startswith(b"temporary discovery v5"):
            self.logger.debug("Got discovery v5 msg, discarding")
            return

        self.receive(kademlia.Address(ip_address, udp_port), data)  # type: ignore

    def error_received(self, exc: Exception) -> None:
        self.logger.error('error received: %s', exc)

    def send(self, node: kademlia.Node, message: bytes) -> None:
        self.transport.sendto(message, (node.address.ip, node.address.udp_port))

    async def stop(self):
        self.logger.info('stopping discovery')
        self.cancel_token.trigger()
        self.transport.close()
        # We run lots of asyncio tasks so this is to make sure they all get a chance to execute
        # and exit cleanly when they notice the cancel token has been triggered.
        await asyncio.sleep(1)

    def receive(self, address: kademlia.Address, message: bytes) -> None:
        try:
            remote_pubkey, cmd_id, payload, message_hash = _unpack(message)
        except DefectiveMessage as e:
            self.logger.error('error unpacking message (%s) from %s: %s', message, address, e)
            return

        # As of discovery version 4, expiration is the last element for all packets, so
        # we can validate that here, but if it changes we may have to do so on the
        # handler methods.
        expiration = rlp.sedes.big_endian_int.deserialize(payload[-1])
        if time.time() > expiration:
            self.logger.error('received message already expired')
            return

        cmd = CMD_ID_MAP[cmd_id]
        if len(payload) != cmd.elem_count:
            self.logger.error('invalid %s payload: %s', cmd.name, payload)
            return
        node = kademlia.Node(remote_pubkey, address)
        handler = self._get_handler(cmd)
        handler(node, payload, message_hash)

    def recv_pong(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The pong payload should have 3 elements: to, token, expiration
        _, token, _ = payload
        self.kademlia.recv_pong(node, token)

    def recv_neighbours(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The neighbours payload should have 2 elements: nodes, expiration
        nodes, _ = payload
        self.kademlia.recv_neighbours(node, _extract_nodes_from_payload(nodes))

    def recv_ping(self, node: kademlia.Node, _: Any, message_hash: bytes) -> None:
        self.kademlia.recv_ping(node, message_hash)

    def recv_find_node(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The find_node payload should have 2 elements: node_id, expiration
        self.logger.debug('<<< find_node from %s', node)
        node_id, _ = payload
        self.kademlia.recv_find_node(node, big_endian_to_int(node_id))

    def send_ping(self, node: kademlia.Node) -> bytes:
        version = rlp.sedes.big_endian_int.serialize(PROTO_VERSION)
        payload = [version, self.address.to_endpoint(), node.address.to_endpoint()]
        message = _pack(CMD_PING.id, payload, self.privkey)
        self.send(node, message)
        # Return the msg hash, which is used as a token to identify pongs.
        token = message[:MAC_SIZE]
        self.logger.debug('>>> ping %s (token == %s)', node, encode_hex(token))
        return token

    def send_find_node(self, node: kademlia.Node, target_node_id: int) -> None:
        node_id = int_to_big_endian(
            target_node_id).rjust(kademlia.k_pubkey_size // 8, b'\0')
        self.logger.debug('>>> find_node to %s', node)
        message = _pack(CMD_FIND_NODE.id, [node_id], self.privkey)
        self.send(node, message)

    def send_pong(self, node: kademlia.Node, token: bytes) -> None:
        self.logger.debug('>>> pong %s', node)
        payload = [node.address.to_endpoint(), token]
        message = _pack(CMD_PONG.id, payload, self.privkey)
        self.send(node, message)

    def send_neighbours(self, node: kademlia.Node, neighbours: List[kademlia.Node]) -> None:
        nodes = []
        neighbours = sorted(neighbours)
        for n in neighbours:
            nodes.append(n.address.to_endpoint() + [n.pubkey.to_bytes()])

        max_neighbours = self._get_max_neighbours_per_packet()
        for i in range(0, len(nodes), max_neighbours):
            message = _pack(CMD_NEIGHBOURS.id, [nodes[i:i + max_neighbours]], self.privkey)
            self.logger.debug('>>> neighbours to %s: %s',
                              node, neighbours[i:i + max_neighbours])
            self.send(node, message)
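
The chunking loop at the end of send_neighbours() above, in isolation (12 is an arbitrary stand-in for the computed max_neighbours cap):

nodes = list(range(30))  # stand-ins for encoded neighbour entries
max_neighbours = 12
packets = [nodes[i:i + max_neighbours] for i in range(0, len(nodes), max_neighbours)]
assert [len(packet) for packet in packets] == [12, 12, 6]
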
Example #15
class BasePeer(metaclass=ABCMeta):
    logger = logging.getLogger("p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: sha3.keccak_256,
                 ingress_mac: sha3.keccak_256,
                 chaindb: AsyncChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    @abstractmethod
    async def send_sub_proto_handshake(self):
        raise NotImplementedError("Must be implemented by subclasses")

    @abstractmethod
    async def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        raise NotImplementedError("Must be implemented by subclasses")

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        await self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        await self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        Raises OperationCancelled if the peer has been disconnected.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        return await wait_with_token(self.sub_proto_msg_queue.get(), token=combined_token)

    @property
    async def genesis(self) -> BlockHeader:
        genesis_hash = await self.chaindb.coro_lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return await self.chaindb.coro_get_block_header_by_hash(genesis_hash)

    @property
    async def _local_chain_info(self) -> 'ChainInfo':
        genesis = await self.genesis
        head = await self.chaindb.coro_get_canonical_head()
        total_difficulty = await self.chaindb.coro_get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[str, int]]:
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg."""
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length:
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), token=self.cancel_token, timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError) as e:
            raise PeerConnectionLost(repr(e))

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.exception("Unexpected error when handling remote msg")
        finally:
            self.close()
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def is_finished(self) -> bool:
        return self._finished.is_set()

    async def wait_until_finished(self) -> bool:
        return await self._finished.wait()

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, TimeoutError) as e:
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return

            try:
                self.process_msg(cmd, msg)
            except RemoteDisconnected as e:
                self.logger.info("%s disconnected: %s", self, e)
                return

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to the 16-byte
        # boundary, so we need to round up here to ensure we read all of the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        loop = asyncio.get_event_loop()
        decoded_msg = await wait_with_token(
            loop.run_in_executor(None, cmd.decode, msg),
            token=self.cancel_token)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Handle the base protocol (P2P) messages."""
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            raise RemoteDisconnected(msg['reason_name'])
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        if cmd.is_base_protocol:
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        try:
            self.sub_proto = self.select_sub_protocol(remote_capabilities)
        except NoMatchingPeerCapabilities:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size
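
    # A quick sanity check of the 3-byte big-endian decoding above (illustrative
    # values only):
    #
    #   >>> struct.unpack(b'>I', b'\x00' + b'\x00\x00\xa8')
    #   (168,)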

    def send(self, header: bytes, body: bytes) -> None:
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))
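
    # The body's first byte is the RLP-encoded command id. Single-byte values below
    # 0x80 encode to themselves under RLP, so for illustration:
    #
    #   >>> rlp.decode(b'\x10', sedes=sedes.big_endian_int)
    #   16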

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.
        """
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> protocol.Protocol:
        """Select the sub-protocol to use when talking to the remote.

        Returns an instance (with the appropriate cmd_id offset) of the highest version of
        our supported sub-protocols that is also supported by the remote.

        Raises NoMatchingPeerCapabilities if none of our supported protocols match any of
        the remote's protocols.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        if not matching_capabilities:
            raise NoMatchingPeerCapabilities()
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        raise NoMatchingPeerCapabilities()
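
    # Illustration of the matching rule above, with hypothetical capability sets:
    # given ours = {(b'eth', 62), (b'eth', 63)} and a remote advertising
    # [(b'eth', 63), (b'les', 1)], the intersection is {(b'eth', 63)}, so the
    # eth/63 protocol class is instantiated with its commands offset by
    # base_protocol.cmd_length.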

    def __str__(self) -> str:
        return "{} {}".format(self.__class__.__name__, self.remote)
Ejemplo n.º 18
0
class StateDownloader(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.state.StateDownloader")
    _pending_nodes = {}  # type: Dict[Any, float]
    _total_processed_nodes = 0
    _report_interval = 10  # Number of seconds between progress reports.
    _reply_timeout = 20  # seconds
    _start_time = None  # type: float
    _total_timeouts = 0

    def __init__(self,
                 state_db: BaseDB,
                 root_hash: bytes,
                 peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.root_hash = root_hash
        self.scheduler = StateSync(root_hash, state_db)
        self._running_peers = set()  # type: Set[ETHPeer]
        self._peers_with_pending_requests = {}  # type: Dict[ETHPeer, float]
        # Shadow the class-level default with a per-instance dict so two downloaders
        # never share pending-node state.
        self._pending_nodes = {}  # type: Dict[Any, float]
        self.cancel_token = CancelToken('StateDownloader')
        if token is not None:
            self.cancel_token = self.cancel_token.chain(token)

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))

    @property
    def idle_peers(self) -> List[ETHPeer]:
        peers = {cast(ETHPeer, peer) for peer in self.peer_pool.peers}
        return list(peers.difference(self._peers_with_pending_requests))

    async def get_idle_peer(self) -> ETHPeer:
        while not self.idle_peers:
            self.logger.debug("Waiting for an idle peer...")
            await wait_with_token(asyncio.sleep(0.02), token=self.cancel_token)
        return secrets.choice(self.idle_peers)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            # Run self._handle_msg() with ensure_future() instead of awaiting for it so that we
            # can keep consuming msgs while _handle_msg() performs cpu-intensive tasks in separate
            # processes.
            asyncio.ensure_future(self._handle_msg(peer, cmd, msg))

    async def _handle_msg(
            self, peer: ETHPeer, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        loop = asyncio.get_event_loop()
        if isinstance(cmd, eth.NodeData):
            self.logger.debug("Got %d NodeData entries from %s", len(msg), peer)

            # Check before we remove because sometimes a reply may come after our timeout and in
            # that case we won't be expecting it anymore.
            if peer in self._peers_with_pending_requests:
                self._peers_with_pending_requests.pop(peer)

            node_keys = await loop.run_in_executor(None, list, map(keccak, msg))
            for node_key, node in zip(node_keys, msg):
                self._total_processed_nodes += 1
                try:
                    self.scheduler.process([(node_key, node)])
                except SyncRequestAlreadyProcessed:
                    # This means we received a node more than once, which can happen when we
                    # retry after a timeout.
                    pass
                # A node may be received more than once, so pop() with a default value.
                self._pending_nodes.pop(node_key, None)
        else:
            # We ignore everything that is not a NodeData when doing a StateSync.
            self.logger.debug("Ignoring %s msg while doing a StateSync", cmd)

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        while self._running_peers:
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)

    async def request_nodes(self, node_keys: List[bytes]) -> None:
        batches = list(partition_all(eth.MAX_STATE_FETCH, node_keys))
        for batch in batches:
            peer = await self.get_idle_peer()
            now = time.time()
            for node_key in batch:
                self._pending_nodes[node_key] = now
            self.logger.debug("Requesting %d trie nodes to %s", len(batch), peer)
            peer.sub_proto.send_get_node_data(batch)
            self._peers_with_pending_requests[peer] = now
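
    # partition_all() (assumed here to be toolz/cytoolz.partition_all) splits the key
    # list into chunks of at most eth.MAX_STATE_FETCH keys; e.g. with a hypothetical
    # chunk size of 3:
    #
    #   >>> list(partition_all(3, [1, 2, 3, 4, 5]))
    #   [(1, 2, 3), (4, 5)]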

    async def _periodically_retry_timedout(self) -> None:
        while True:
            now = time.time()
            # First, update our list of peers with pending requests by removing those for which a
            # request timed out. This loop mutates the dict, so we iterate on a copy of it.
            for peer, last_req_time in list(self._peers_with_pending_requests.items()):
                if now - last_req_time > self._reply_timeout:
                    self._peers_with_pending_requests.pop(peer)

            # Now re-send requests for nodes that timed out.
            oldest_request_time = now
            timed_out = []
            for node_key, req_time in self._pending_nodes.items():
                if now - req_time > self._reply_timeout:
                    timed_out.append(node_key)
                elif req_time < oldest_request_time:
                    oldest_request_time = req_time
            if timed_out:
                self.logger.debug("Re-requesting %d trie nodes", len(timed_out))
                self._total_timeouts += len(timed_out)
                try:
                    await self.request_nodes(timed_out)
                except OperationCancelled:
                    break

            # Finally, sleep until the time our oldest request is scheduled to timeout.
            now = time.time()
            sleep_duration = (oldest_request_time + self._reply_timeout) - now
            try:
                await wait_with_token(asyncio.sleep(sleep_duration), token=self.cancel_token)
            except OperationCancelled:
                break
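
    # Worked example of the sleep computation above (hypothetical numbers): with
    # _reply_timeout = 20s and an oldest still-pending request made 12s ago,
    # sleep_duration = (oldest_request_time + 20) - now = 8s, i.e. we wake exactly
    # when that request is due to time out.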

    async def run(self) -> None:
        """Fetch all trie nodes starting from self.root_hash, and store them in self.db.

        Raises OperationCancelled if we're interrupted before that is completed.
        """
        self._start_time = time.time()
        self.logger.info("Starting state sync for root hash %s", encode_hex(self.root_hash))
        asyncio.ensure_future(self._periodically_report_progress())
        asyncio.ensure_future(self._periodically_retry_timedout())
        while self.scheduler.has_pending_requests:
            # This ensures we yield control and give _handle_msg() a chance to process any nodes
            # we may have received already, also ensuring we exit when our cancel token is
            # triggered.
            await wait_with_token(asyncio.sleep(0), token=self.cancel_token)

            requests = self.scheduler.next_batch(eth.MAX_STATE_FETCH)
            if not requests:
                # Although we frequently yield control above to let our msg handler process
                # received nodes (scheduling new requests), pending nodes may take a while to
                # arrive, leaving the scheduler without new requests for a while.
                self.logger.info("Scheduler queue is empty, sleeping a bit")
                await wait_with_token(asyncio.sleep(0.5), token=self.cancel_token)
                continue

            await self.request_nodes([request.node_key for request in requests])

        self.logger.info("Finished state sync with root hash %s", encode_hex(self.root_hash))

    async def _periodically_report_progress(self) -> None:
        while True:
            now = time.time()
            self.logger.info("====== State sync progress ========")
            self.logger.info("Nodes processed: %d", self._total_processed_nodes)
            self.logger.info("Nodes processed per second (average): %d",
                             self._total_processed_nodes / (now - self._start_time))
            self.logger.info("Nodes committed to DB: %d", self.scheduler.committed_nodes)
            self.logger.info(
                "Nodes requested but not received yet: %d", len(self._pending_nodes))
            self.logger.info(
                "Nodes scheduled but not requested yet: %d", len(self.scheduler.requests))
            self.logger.info("Total nodes timed out: %d", self._total_timeouts)
            try:
                await wait_with_token(asyncio.sleep(self._report_interval), token=self.cancel_token)
            except OperationCancelled:
                break
Ejemplo n.º 19
0
class ShardSyncer(BaseService, PeerPoolSubscriber):
    logger = logging.getLogger("p2p.sharding.ShardSyncer")

    def __init__(self, shard: Shard, peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        # BaseService.__init__() already creates a "ShardSyncer" token and chains it
        # with the one passed in, so there is no need to rebuild self.cancel_token here.
        super().__init__(token)

        self.shard = shard
        self.peer_pool = peer_pool

        self.incoming_collation_queue: asyncio.Queue[Collation] = asyncio.Queue()

        self.collations_received_event = asyncio.Event()

        self.start_time = time.time()

    async def _run(self) -> None:
        self.peer_pool.subscribe(self)
        while True:
            collation = await wait_with_token(
                self.incoming_collation_queue.get(), token=self.cancel_token)

            if collation.shard_id != self.shard.shard_id:
                self.logger.debug("Ignoring received collation belonging to a different shard")
                continue
            if self.shard.get_availability(collation.header) is Availability.AVAILABLE:
                self.logger.debug("Ignoring already available collation")
                continue

            self.logger.debug("Adding collation {} to shard".format(collation))
            self.shard.add_collation(collation)
            for peer in self.peer_pool.peers:
                cast(ShardingPeer, peer).send_collations([collation])

            # set() followed immediately by clear() wakes every coroutine currently
            # waiting on collations_received_event without leaving the event set for
            # waiters that arrive later.
            self.collations_received_event.set()
            self.collations_received_event.clear()

    async def _cleanup(self) -> None:
        self.peer_pool.unsubscribe(self)

    def propose(self) -> Collation:
        """Broadcast a new collation to the network, add it to the local shard, and return it."""
        # create collation for current period
        period = self.get_current_period()
        body = zpad_right(str(self).encode("utf-8"), COLLATION_SIZE)
        header = CollationHeader(self.shard.shard_id, calc_chunk_root(body),
                                 period, b"\x11" * 20)  # placeholder proposer address
        collation = Collation(header, body)

        self.logger.debug("Proposing collation {}".format(collation))

        # add collation to local chain
        self.shard.add_collation(collation)

        # broadcast collation
        for peer in self.peer_pool.peers:
            cast(ShardingPeer, peer).send_collations([collation])

        return collation
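
    # zpad_right() (assumed helper) pads the body out to the fixed COLLATION_SIZE with
    # trailing zero bytes; for illustration:
    #
    #   >>> zpad_right(b'abc', 8)
    #   b'abc\x00\x00\x00\x00\x00'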

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ShardingPeer, peer)))

    async def handle_peer(self, peer: ShardingPeer) -> None:
        while True:
            try:
                collation = await wait_with_token(
                    peer.incoming_collation_queue.get(), token=self.cancel_token)
                await wait_with_token(
                    self.incoming_collation_queue.put(collation), token=self.cancel_token)
            except OperationCancelled:
                break  # stop handling peer if cancel token is triggered

    def get_current_period(self) -> int:
        # TODO: get this from main chain
        return int((time.time() - self.start_time) // COLLATION_PERIOD)
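
    # Worked example (hypothetical COLLATION_PERIOD of 5 seconds): 12 seconds after
    # start_time this returns int(12 // 5) == 2, i.e. we are in the third local
    # period (periods are numbered from 0).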