The examples below show real-world usage of the CancelToken cooperative-cancellation API (trigger(), triggered, triggered_token, chain(), wait(), and the wait_with_token() helper) in the py-evm/trinity p2p code.

Example #1
class FullNodeSyncer:
    logger = logging.getLogger("p2p.sync.FullNodeSyncer")

    def __init__(self, chaindb: AsyncChainDB, peer_pool: PeerPool) -> None:
        self.chaindb = chaindb
        self.peer_pool = peer_pool
        self.cancel_token = CancelToken('FullNodeSyncer')

    async def run(self) -> None:
        # Fast-sync chain data.
        chain_syncer = ChainSyncer(self.chaindb, self.peer_pool,
                                   self.cancel_token)
        try:
            await chain_syncer.run()
        finally:
            await chain_syncer.stop()

        # Download state for our current head.
        head = self.chaindb.get_canonical_head()
        downloader = StateDownloader(self.chaindb.db, head.state_root,
                                     self.peer_pool, self.cancel_token)
        try:
            await downloader.run()
        finally:
            await downloader.stop()

        # TODO: Run the regular sync.

    async def stop(self) -> None:
        self.cancel_token.trigger()
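
A minimal driver for the syncer above might look like the following sketch. The import is inferred from the logger name ("p2p.sync.FullNodeSyncer"), and the chaindb/peer_pool objects are assumed to be constructed elsewhere, so treat the wiring as illustrative rather than canonical.

import asyncio

from p2p.sync import FullNodeSyncer  # module path inferred from the logger name

async def main(chaindb, peer_pool):
    # chaindb/peer_pool are assumed to be built elsewhere (AsyncChainDB, PeerPool)
    syncer = FullNodeSyncer(chaindb, peer_pool)
    try:
        await syncer.run()
    finally:
        await syncer.stop()

# asyncio.get_event_loop().run_until_complete(main(chaindb, peer_pool))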
Example #2
async def test_wait_cancel_pending_tasks_on_completion(event_loop):
    token = CancelToken('token')
    token2 = CancelToken('token2')
    chain = token.chain(token2)
    token2.trigger()
    await chain.wait()
    await assert_only_current_task_not_done()
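
The blocking side of the same API: a coroutine that awaits token.wait() is unblocked as soon as the token (or any token chained into it) fires, and triggered_token reports the origin. A minimal sketch, with the import path assumed:

import asyncio

from p2p.cancel_token import CancelToken  # assumed import path

async def demo():
    token = CancelToken('service')
    waiter = asyncio.ensure_future(token.wait())  # blocks until triggered
    token.trigger()
    await waiter
    assert token.triggered_token == token

asyncio.get_event_loop().run_until_complete(demo())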
Example #3
class IPCServer:
    logger = logging.getLogger('trinity.rpc.ipc.IPCServer')

    cancel_token = None
    ipc_path = None
    rpc = None
    server = None

    def __init__(self, rpc, ipc_path):
        self.rpc = rpc
        self.ipc_path = ipc_path
        self.cancel_token = CancelToken('IPCServer')

    async def run(self, loop=None):
        self.server = await asyncio.start_unix_server(
            connection_handler(self.rpc.execute, self.cancel_token),
            self.ipc_path,
            loop=loop,
            limit=MAXIMUM_REQUEST_BYTES,
        )
        self.logger.info('ipc-path: %s', os.path.abspath(self.ipc_path))
        await self.cancel_token.wait()

    async def stop(self):
        self.cancel_token.trigger()
        self.server.close()
        await self.server.wait_closed()
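
Because run() only returns after the cancel token fires, callers typically run it to completion and rely on stop() as the escape hatch. A sketch, assuming rpc and ipc_path are built elsewhere:

import asyncio

loop = asyncio.get_event_loop()
server = IPCServer(rpc, ipc_path)  # rpc/ipc_path assumed to exist
try:
    loop.run_until_complete(server.run(loop))  # returns only after stop()
except KeyboardInterrupt:
    loop.run_until_complete(server.stop())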
Example #4
class IPCServer:
    logger = logging.getLogger('trinity.rpc.ipc.IPCServer')

    cancel_token = None
    ipc_path = None
    rpc = None
    server = None

    def __init__(self, rpc: RPCServer, ipc_path: pathlib.Path) -> None:
        self.rpc = rpc
        self.ipc_path = ipc_path

    async def run(self, loop: asyncio.AbstractEventLoop = None) -> None:
        ipc_path = str(self.ipc_path)
        self.cancel_token = CancelToken('IPCServer', loop=loop)
        self.server = await asyncio.start_unix_server(
            connection_handler(self.rpc.execute, self.cancel_token),
            ipc_path,
            loop=loop,
            limit=MAXIMUM_REQUEST_BYTES,
        )
        self.logger.info('IPC started at: %s', os.path.abspath(ipc_path))
        await self.cancel_token.wait()

    async def stop(self) -> None:
        if self.cancel_token is not None:
            self.cancel_token.trigger()
        self.server.close()
        await self.server.wait_closed()
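
Note the difference from the earlier IPCServer: here the CancelToken is created inside run() with an explicit loop argument rather than in __init__, which is why stop() guards against cancel_token being None when the server was never started.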
Example #5
def test_token_chain_trigger_last():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    token3 = CancelToken('token3')
    chain = token.chain(token2).chain(token3)
    assert not chain.triggered
    token3.trigger()
    assert chain.triggered
    assert chain.triggered_token == token3
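
The complementary case follows from the same semantics: triggering a token earlier in the chain also fires the chained token, and triggered_token records which one. This mirror of the test above is a sketch, not copied from the test suite:

def test_token_chain_trigger_first_sketch():
    token = CancelToken('token')
    token2 = CancelToken('token2')
    chain = token.chain(token2)
    token.trigger()
    assert chain.triggered
    assert chain.triggered_token == token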
Example #6
async def test_shard_syncer(n_peers, connections):
    cancel_token = CancelToken("canceltoken")

    PeerTuple = collections.namedtuple("PeerTuple",
                                       ["node", "server", "syncer"])
    peer_tuples = []

    for i in range(n_peers):
        private_key = keys.PrivateKey(pad32(int_to_big_endian(i + 1)))
        port = get_open_port()
        address = Address("127.0.0.1", port, port)

        node = Node(private_key.public_key, address)

        server = ShardingServer(private_key,
                                address,
                                network_id=9324090483,
                                min_peers=0,
                                peer_class=ShardingPeer)
        asyncio.ensure_future(server.run())

        peer_tuples.append(
            PeerTuple(
                node=node,
                server=server,
                syncer=server.syncer,
            ))

    # connect peers to each other
    await asyncio.gather(*[
        peer_tuples[i].server.peer_pool._connect_to_nodes(
            [peer_tuples[j].node]) for i, j in connections
    ])
    for i, j in connections:
        peer_remotes = [
            peer.remote for peer in peer_tuples[i].server.peer_pool.peers
        ]
        assert peer_tuples[j].node in peer_remotes

    # let each node propose and check that collation appears at all other nodes
    for proposer in peer_tuples:
        collation = proposer.syncer.propose()
        await asyncio.wait_for(asyncio.gather(*[
            peer_tuple.syncer.collations_received_event.wait()
            for peer_tuple in peer_tuples if peer_tuple != proposer
        ]),
                               timeout=10)
        for peer_tuple in peer_tuples:
            assert peer_tuple.syncer.shard.get_collation_by_hash(
                collation.hash) == collation

    # stop everything
    cancel_token.trigger()
    await asyncio.gather(
        *[peer_tuple.server.cancel() for peer_tuple in peer_tuples])
    await asyncio.gather(
        *[peer_tuple.syncer.cancel() for peer_tuple in peer_tuples])
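
Note that the cancel_token created at the top of this test is never passed to the servers or syncers, so the trigger() call is effectively a no-op here; teardown actually happens through the server.cancel() and syncer.cancel() calls.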
Example #7
class IPCServer:
    cancel_token = None
    ipc_path = None
    rpc = None
    server = None

    def __init__(self, rpc, ipc_path):
        self.rpc = rpc
        self.ipc_path = ipc_path
        self.cancel_token = CancelToken('IPCServer')

    async def run(self, loop=None):
        self.server = await asyncio.start_unix_server(
            connection_handler(self.rpc.execute, self.cancel_token),
            self.ipc_path,
            loop=loop,
            limit=MAXIMUM_REQUEST_BYTES,
        )
        await self.cancel_token.wait()

    async def stop(self):
        self.cancel_token.trigger()
        self.server.close()
        await self.server.wait_closed()
Example #8
class BasePeer(metaclass=ABCMeta):
    logger = logging.getLogger("p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: sha3.keccak_256,
                 ingress_mac: sha3.keccak_256,
                 chaindb: AsyncChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    @abstractmethod
    async def send_sub_proto_handshake(self):
        raise NotImplementedError("Must be implemented by subclasses")

    @abstractmethod
    async def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        raise NotImplementedError("Must be implemented by subclasses")

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        await self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        await self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        Raises OperationCancelled if the peer has been disconnected.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        return await wait_with_token(self.sub_proto_msg_queue.get(), token=combined_token)

    @property
    async def genesis(self) -> BlockHeader:
        genesis_hash = await self.chaindb.coro_lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return await self.chaindb.coro_get_block_header_by_hash(genesis_hash)

    @property
    async def _local_chain_info(self) -> 'ChainInfo':
        genesis = await self.genesis
        head = await self.chaindb.coro_get_canonical_head()
        total_difficulty = await self.chaindb.coro_get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[str, int]]:
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg."""
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length:
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), token=self.cancel_token, timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError) as e:
            raise PeerConnectionLost(repr(e))

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.exception("Unexpected error when handling remote msg")
        finally:
            self.close()
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def is_finished(self) -> bool:
        return self._finished.is_set()

    async def wait_until_finished(self) -> bool:
        return await self._finished.wait()

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, TimeoutError) as e:
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return

            try:
                self.process_msg(cmd, msg)
            except RemoteDisconnected as e:
                self.logger.info("%s disconnected: %s", self, e)
                return

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to 16-byte boundary,
        # so need to do this here to ensure we read all the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        loop = asyncio.get_event_loop()
        decoded_msg = await wait_with_token(
            loop.run_in_executor(None, cmd.decode, msg),
            token=self.cancel_token)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Handle the base protocol (P2P) messages."""
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            raise RemoteDisconnected(msg['reason_name'])
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        if cmd.is_base_protocol:
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        try:
            self.sub_proto = self.select_sub_protocol(remote_capabilities)
        except NoMatchingPeerCapabilities:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; Got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size

    def send(self, header: bytes, body: bytes) -> None:
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.
        """
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> protocol.Protocol:
        """Select the sub-protocol to use when talking to the remote.

        Find the highest version of our supported sub-protocols that is also supported by the
        remote and stores an instance of it (with the appropriate cmd_id offset) in
        self.sub_proto.

        Raises NoMatchingPeerCapabilities if none of our supported protocols match one of the
        remote's protocols.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        if not matching_capabilities:
            raise NoMatchingPeerCapabilities()
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        raise NoMatchingPeerCapabilities()

    def __str__(self):
        return "{} {}".format(self.__class__.__name__, self.remote)
Example #9
class PeerPool:
    """PeerPool attempts to keep connections to at least min_peers on the given network."""
    logger = logging.getLogger("p2p.peer.PeerPool")
    _connect_loop_sleep = 2
    _last_lookup = 0  # type: float
    _lookup_interval = 5  # type: int

    def __init__(self,
                 peer_class: Type[BasePeer],
                 chaindb: AsyncChainDB,
                 network_id: int,
                 privkey: datatypes.PrivateKey,
                 discovery: DiscoveryProtocol,
                 min_peers: int = DEFAULT_MIN_PEERS,
                 ) -> None:
        self.peer_class = peer_class
        self.chaindb = chaindb
        self.network_id = network_id
        self.privkey = privkey
        self.discovery = discovery
        self.min_peers = min_peers
        self.connected_nodes = {}  # type: Dict[Node, BasePeer]
        self.cancel_token = CancelToken('PeerPool')
        self._subscribers = []  # type: List[PeerPoolSubscriber]

    def get_nodes_to_connect(self) -> Generator[Node, None, None]:
        return self.discovery.get_random_nodes(self.min_peers)

    def subscribe(self, subscriber: PeerPoolSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)

    def unsubscribe(self, subscriber: PeerPoolSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)

    def start_peer(self, peer):
        asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
        self.add_peer(peer)

    def add_peer(self, peer):
        self.logger.debug('Adding peer (%s) ...', str(peer))
        self.connected_nodes[peer.remote] = peer
        self.logger.debug('Number of peers: %d', len(self.connected_nodes))
        for subscriber in self._subscribers:
            subscriber.register_peer(peer)

    async def run(self) -> None:
        self.logger.info("Running PeerPool...")
        while not self.cancel_token.triggered:
            try:
                await self.maybe_connect_to_more_peers()
            except OperationCancelled as e:
                self.logger.debug("PeerPool finished: %s", e)
                break
            except:  # noqa: E722
                # Most unexpected errors should be transient, so we log and restart from scratch.
                self.logger.exception("Unexpected error, restarting")
                await self.stop_all_peers()
            # Wait self._connect_loop_sleep seconds, unless we're asked to stop.
            await asyncio.wait([self.cancel_token.wait()], timeout=self._connect_loop_sleep)

    async def stop_all_peers(self) -> None:
        self.logger.info("Stopping all peers ...")
        await asyncio.gather(
            *[peer.stop() for peer in self.connected_nodes.values()])

    async def stop(self) -> None:
        self.cancel_token.trigger()
        await self.stop_all_peers()

    async def connect(self, remote: Node) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Returns None if the remote is unreachable, times out or is useless.
        """
        if remote in self.connected_nodes:
            self.logger.debug("Skipping %s; already connected to it", remote)
            return None
        expected_exceptions = (
            UnreachablePeer, TimeoutError, PeerConnectionLost, HandshakeFailure)
        try:
            self.logger.debug("Connecting to %s...", remote)
            peer = await wait_with_token(
                handshake(remote, self.privkey, self.peer_class, self.chaindb, self.network_id),
                token=self.cancel_token,
                timeout=HANDSHAKE_TIMEOUT)
            return peer
        except OperationCancelled:
            # Pass it on to instruct our main loop to stop.
            raise
        except expected_exceptions as e:
            self.logger.debug("Could not complete handshake with %s: %s", remote, repr(e))
        except Exception:
            self.logger.exception("Unexpected error during auth/p2p handshake with %s", remote)
        return None

    async def lookup_random_node(self) -> None:
        # This method runs in the background, so we must catch OperationCancelled here otherwise
        # asyncio will warn that its exception was never retrieved.
        try:
            await self.discovery.lookup_random(self.cancel_token)
        except OperationCancelled:
            pass
        self._last_lookup = time.time()

    async def maybe_connect_to_more_peers(self) -> None:
        """Connect to more peers if we're not yet connected to at least self.min_peers."""
        if len(self.connected_nodes) >= self.min_peers:
            self.logger.debug(
                "Already connected to %s peers: %s; sleeping",
                len(self.connected_nodes),
                [remote for remote in self.connected_nodes])
            return

        if self._last_lookup + self._lookup_interval < time.time():
            asyncio.ensure_future(self.lookup_random_node())

        await self._connect_to_nodes(self.get_nodes_to_connect())

        # In some cases (e.g ROPSTEN or private testnets), the discovery table might be full of
        # bad peers so if we can't connect to any peers we try a random bootstrap node as well.
        if not self.peers:
            await self._connect_to_nodes(self._get_random_bootnode())

    def _get_random_bootnode(self) -> Generator[Node, None, None]:
        if self.discovery.bootstrap_nodes:
            yield random.choice(self.discovery.bootstrap_nodes)
        else:
            self.logger.warning('No bootstrap_nodes')

    async def _connect_to_nodes(self, nodes: Generator[Node, None, None]) -> None:
        for node in nodes:
            # TODO: Consider changing connect() to raise an exception instead of returning None,
            # as discussed in
            # https://github.com/ethereum/py-evm/pull/139#discussion_r152067425
            peer = await self.connect(node)
            if peer is not None:
                self.logger.info("Successfully connected to %s", peer)
                self.start_peer(peer)
                if len(self.connected_nodes) >= self.min_peers:
                    return

    def _peer_finished(self, peer: BasePeer) -> None:
        """Remove the given peer from our list of connected nodes.
        This is passed as a callback to be called when a peer finishes.
        """
        if peer.remote in self.connected_nodes:
            self.connected_nodes.pop(peer.remote)

    @property
    def peers(self) -> List[BasePeer]:
        peers = list(self.connected_nodes.values())
        # Shuffle the list of peers so that dumb callsites are less likely to send all requests to
        # a single peer even if they always pick the first one from the list.
        random.shuffle(peers)
        return peers

    async def get_random_peer(self) -> BasePeer:
        while not self.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)
        return random.choice(self.peers)
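
PeerPool.run() also shows the standard shutdown-aware sleep: instead of asyncio.sleep(), it waits on the token with a timeout so the loop wakes immediately when stop() is called. Isolated as a helper (a sketch, not part of the library):

import asyncio

async def sleep_or_cancel(token, seconds):
    # Sleeps for `seconds`, but wakes early if `token` is triggered;
    # mirrors the last line of PeerPool.run() above.
    await asyncio.wait([token.wait()], timeout=seconds)
    return token.triggered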
Example #10
class LightChain(Chain, PeerPoolSubscriber):
    logger = logging.getLogger("p2p.lightchain.LightChain")
    max_consecutive_timeouts = 5

    def __init__(self, chaindb: AsyncChainDB, peer_pool: PeerPool) -> None:
        super(LightChain, self).__init__(chaindb)
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self._announcement_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[LESPeer, les.HeadInfo]]  # noqa: E501
        self._last_processed_announcements = {}  # type: Dict[LESPeer, les.HeadInfo]
        self.cancel_token = CancelToken('LightChain')
        self._running_peers = set()  # type: Set[LESPeer]

    @classmethod
    def from_genesis_header(cls, chaindb, genesis_header, peer_pool):
        chaindb.persist_header_to_db(genesis_header)
        return cls(chaindb, peer_pool)

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(LESPeer, peer)))

    async def handle_peer(self, peer: LESPeer) -> None:
        """Handle the lifecycle of the given peer.

        Returns when the peer is finished or when the LightChain is asked to stop.
        """
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: LESPeer) -> None:
        self._announcement_queue.put_nowait((peer, peer.head_info))
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either the peer disconnected or our cancel_token has been triggered, so break
                # out of the loop to stop attempting to sync with this peer.
                break
            # We currently implement only the LES commands for retrieving data (apart from
            # Announce), and those should always come as a response to requests we make so will be
            # handled by LESPeer.handle_sub_proto_msg().
            if isinstance(cmd, les.Announce):
                peer.head_info = cmd.as_head_info(msg)
                self._announcement_queue.put_nowait((peer, peer.head_info))
            else:
                raise UnexpectedMessage("Unexpected msg from {}: {}".format(peer, msg))

        await self.drop_peer(peer)
        self.logger.debug("%s finished", peer)

    async def drop_peer(self, peer: LESPeer) -> None:
        self._last_processed_announcements.pop(peer, None)
        await peer.stop()

    async def wait_until_finished(self):
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)
        else:
            self.logger.info("Waited too long for peers to finish, exiting anyway")

    async def get_best_peer(self) -> LESPeer:
        """
        Return the peer with the highest announced block height.
        """
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)

        def peer_block_height(peer: LESPeer) -> int:
            last_announced = self._last_processed_announcements.get(peer)
            if last_announced is None:
                return -1
            return last_announced.block_number

        # TODO: Should pick a random one in case there are multiple peers with the same block
        # height.
        return max(
            [cast(LESPeer, peer) for peer in self.peer_pool.peers],
            key=peer_block_height)

    async def wait_for_announcement(self) -> Tuple[LESPeer, les.HeadInfo]:
        """Wait for a new announcement from any of our connected peers.

        Returns a tuple containing the LESPeer on which the announcement was received and the
        announcement info.

        Raises OperationCancelled when LightChain.stop() has been called.
        """
        # Wait for either a new announcement or our cancel_token to be triggered.
        return await wait_with_token(self._announcement_queue.get(), token=self.cancel_token)

    async def run(self) -> None:
        """Run the LightChain, ensuring headers are in sync with connected peers.

        If .stop() is called, we'll disconnect from all peers and return.
        """
        self.logger.info("Running LightChain...")
        while True:
            try:
                peer, head_info = await self.wait_for_announcement()
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break

            try:
                await self.process_announcement(peer, head_info)
                self._last_processed_announcements[peer] = head_info
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break
            except LESAnnouncementProcessingError as e:
                self.logger.warning(repr(e))
                await self.drop_peer(peer)
            except Exception as e:
                self.logger.error(
                    "Unexpected error when processing announcement: %s", repr(e))
                await self.drop_peer(peer)

    async def fetch_headers(self, start_block: int, peer: LESPeer) -> List[BlockHeader]:
        for i in range(self.max_consecutive_timeouts):
            try:
                return await peer.fetch_headers_starting_at(start_block, self.cancel_token)
            except asyncio.TimeoutError:
                self.logger.info(
                    "Timeout when fetching headers from %s (attempt %d of %d)",
                    peer, i + 1, self.max_consecutive_timeouts)
                # TODO: Figure out what's a good value to use here.
                await asyncio.sleep(0.5)
        raise TooManyTimeouts()

    async def get_sync_start_block(self, peer: LESPeer, head_info: les.HeadInfo) -> int:
        chain_head = await self.chaindb.coro_get_canonical_head()
        last_peer_announcement = self._last_processed_announcements.get(peer)
        if chain_head.block_number == GENESIS_BLOCK_NUMBER:
            start_block = GENESIS_BLOCK_NUMBER
        elif last_peer_announcement is None:
            # It's the first time we hear from this peer, need to figure out which headers to
            # get from it.  We can't simply fetch headers starting from our current head
            # number because there may have been a chain reorg, so we fetch some headers prior
            # to our head from the peer, and insert any missing ones in our DB, essentially
            # making our canonical chain identical to the peer's up to
            # chain_head.block_number.
            oldest_ancestor_to_consider = max(
                0, chain_head.block_number - peer.max_headers_fetch + 1)
            try:
                headers = await self.fetch_headers(oldest_ancestor_to_consider, peer)
            except EmptyGetBlockHeadersReply:
                raise LESAnnouncementProcessingError(
                    "No common ancestors found between us and {}".format(peer))
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(peer))
            for header in headers:
                await self.chaindb.coro_persist_header_to_db(header)
            start_block = chain_head.block_number
        else:
            start_block = last_peer_announcement.block_number - head_info.reorg_depth
        return start_block

    # TODO: Distribute requests among our peers, ensuring the selected peer has the info we want
    # and respecting the flow control rules.
    async def process_announcement(self, peer: LESPeer, head_info: les.HeadInfo) -> None:
        if await self.chaindb.coro_header_exists(head_info.block_hash):
            self.logger.debug(
                "Skipping processing of %s from %s as head has already been fetched",
                head_info, peer)
            return

        start_block = await self.get_sync_start_block(peer, head_info)
        while start_block < head_info.block_number:
            try:
                # We should use "start_block + 1" here, but we always re-fetch the last synced
                # block to work around https://github.com/ethereum/go-ethereum/issues/15447
                batch = await self.fetch_headers(start_block, peer)
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(peer))
            for header in batch:
                await self.chaindb.coro_persist_header_to_db(header)
                start_block = header.block_number
            self.logger.info("synced headers up to #%s", start_block)

    async def stop(self):
        self.logger.info("Stopping LightChain...")
        self.cancel_token.trigger()
        self.logger.debug("Waiting for all pending tasks to finish...")
        await self.wait_until_finished()
        self.logger.debug("LightChain finished")

    async def get_canonical_block_by_number(self, block_number: int) -> BaseBlock:
        """Return the block with the given number from the canonical chain.

        Raises BlockNotFound if it is not found.
        """
        try:
            block_hash = await self.chaindb.coro_lookup_block_hash(block_number)
        except KeyError:
            raise BlockNotFound(
                "No block with number {} found on local chain".format(block_number))
        return await self.get_block_by_hash(block_hash)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_block_by_hash(self, block_hash: bytes) -> BaseBlock:
        peer = await self.get_best_peer()
        try:
            header = await self.chaindb.coro_get_block_header_by_hash(block_hash)
        except BlockNotFound:
            self.logger.debug("Fetching header %s from %s", encode_hex(block_hash), peer)
            header = await peer.get_block_header_by_hash(block_hash, self.cancel_token)

        self.logger.debug("Fetching block %s from %s", encode_hex(block_hash), peer)
        body = await peer.get_block_by_hash(block_hash, self.cancel_token)
        block_class = self.get_vm_class_for_block_number(header.block_number).get_block_class()
        transactions = [
            block_class.transaction_class.from_base_transaction(tx)
            for tx in body.transactions
        ]
        return block_class(
            header=header,
            transactions=transactions,
            uncles=body.uncles,
        )

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_receipts(self, block_hash: bytes) -> List[Receipt]:
        peer = await self.get_best_peer()
        self.logger.debug("Fetching %s receipts from %s", encode_hex(block_hash), peer)
        return await peer.get_receipts(block_hash, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_account(self, block_hash: bytes, address: bytes) -> Account:
        peer = await self.get_best_peer()
        return await peer.get_account(block_hash, address, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_contract_code(self, block_hash: bytes, key: bytes) -> bytes:
        peer = await self.get_best_peer()
        return await peer.get_contract_code(block_hash, key, self.cancel_token)
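
The per-peer loop in _handle_peer shows why tokens chain: peer.read_sub_proto_msg(self.cancel_token) combines the LightChain's token with the peer's own (see BasePeer.read_sub_proto_msg above), so the read unblocks with OperationCancelled whether the whole chain is stopped or just that one peer disconnects.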
Example #11
class LightChain(Chain, PeerPoolSubscriber):
    logger = logging.getLogger("p2p.lightchain.LightChain")
    max_consecutive_timeouts = 5

    def __init__(self, chaindb: AsyncChainDB, peer_pool: PeerPool) -> None:
        super(LightChain, self).__init__(chaindb)
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self._announcement_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[LESPeer, les.HeadInfo]]  # noqa: E501
        self._last_processed_announcements = {}  # type: Dict[LESPeer, les.HeadInfo]
        self.cancel_token = CancelToken('LightChain')
        self._running_peers = set()  # type: Set[LESPeer]

    @classmethod
    def from_genesis_header(cls, chaindb, genesis_header, peer_pool):
        chaindb.persist_header(genesis_header)
        return cls(chaindb, peer_pool)

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(LESPeer, peer)))

    async def handle_peer(self, peer: LESPeer) -> None:
        """Handle the lifecycle of the given peer.

        Returns when the peer is finished or when the LightChain is asked to stop.
        """
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: LESPeer) -> None:
        self._announcement_queue.put_nowait((peer, peer.head_info))
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either the peer disconnected or our cancel_token has been triggered, so break
                # out of the loop to stop attempting to sync with this peer.
                break
            # We currently implement only the LES commands for retrieving data (apart from
            # Announce), and those should always come as a response to requests we make so will be
            # handled by LESPeer.handle_sub_proto_msg().
            if isinstance(cmd, les.Announce):
                peer.head_info = cmd.as_head_info(msg)
                self._announcement_queue.put_nowait((peer, peer.head_info))
            else:
                raise UnexpectedMessage("Unexpected msg from {}: {}".format(
                    peer, msg))

        await self.drop_peer(peer)
        self.logger.debug("%s finished", peer)

    async def drop_peer(self, peer: LESPeer) -> None:
        self._last_processed_announcements.pop(peer, None)
        await peer.stop()

    async def wait_until_finished(self):
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            self.logger.debug("Waiting for %d running peers to finish",
                              len(self._running_peers))
            await asyncio.sleep(0.1)
        else:
            self.logger.info(
                "Waited too long for peers to finish, exiting anyway")

    async def get_best_peer(self) -> LESPeer:
        """
        Return the peer with the highest announced block height.
        """
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)

        def peer_block_height(peer: LESPeer) -> int:
            last_announced = self._last_processed_announcements.get(peer)
            if last_announced is None:
                return -1
            return last_announced.block_number

        # TODO: Should pick a random one in case there are multiple peers with the same block
        # height.
        return max([cast(LESPeer, peer) for peer in self.peer_pool.peers],
                   key=peer_block_height)

    async def wait_for_announcement(self) -> Tuple[LESPeer, les.HeadInfo]:
        """Wait for a new announcement from any of our connected peers.

        Returns a tuple containing the LESPeer on which the announcement was received and the
        announcement info.

        Raises OperationCancelled when LightChain.stop() has been called.
        """
        # Wait for either a new announcement or our cancel_token to be triggered.
        return await wait_with_token(self._announcement_queue.get(),
                                     token=self.cancel_token)

    async def run(self) -> None:
        """Run the LightChain, ensuring headers are in sync with connected peers.

        If .stop() is called, we'll disconnect from all peers and return.
        """
        self.logger.info("Running LightChain...")
        while True:
            try:
                peer, head_info = await self.wait_for_announcement()
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break

            try:
                await self.process_announcement(peer, head_info)
                self._last_processed_announcements[peer] = head_info
            except OperationCancelled:
                self.logger.debug("Asked to stop, breaking out of run() loop")
                break
            except LESAnnouncementProcessingError as e:
                self.logger.warning(repr(e))
                await self.drop_peer(peer)
            except Exception as e:
                self.logger.error(
                    "Unexpected error when processing announcement: %s",
                    repr(e))
                await self.drop_peer(peer)

    async def fetch_headers(self, start_block: int,
                            peer: LESPeer) -> List[BlockHeader]:
        for i in range(self.max_consecutive_timeouts):
            try:
                return await peer.fetch_headers_starting_at(
                    start_block, self.cancel_token)
            except TimeoutError:
                self.logger.info(
                    "Timeout when fetching headers from %s (attempt %d of %d)",
                    peer, i + 1, self.max_consecutive_timeouts)
                # TODO: Figure out what's a good value to use here.
                await asyncio.sleep(0.5)
        raise TooManyTimeouts()

    async def get_sync_start_block(self, peer: LESPeer,
                                   head_info: les.HeadInfo) -> int:
        chain_head = await self.chaindb.coro_get_canonical_head()
        last_peer_announcement = self._last_processed_announcements.get(peer)
        if chain_head.block_number == GENESIS_BLOCK_NUMBER:
            start_block = GENESIS_BLOCK_NUMBER
        elif last_peer_announcement is None:
            # It's the first time we hear from this peer, need to figure out which headers to
            # get from it.  We can't simply fetch headers starting from our current head
            # number because there may have been a chain reorg, so we fetch some headers prior
            # to our head from the peer, and insert any missing ones in our DB, essentially
            # making our canonical chain identical to the peer's up to
            # chain_head.block_number.
            oldest_ancestor_to_consider = max(
                0, chain_head.block_number - peer.max_headers_fetch + 1)
            try:
                headers = await self.fetch_headers(oldest_ancestor_to_consider,
                                                   peer)
            except EmptyGetBlockHeadersReply:
                raise LESAnnouncementProcessingError(
                    "No common ancestors found between us and {}".format(peer))
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(
                        peer))
            for header in headers:
                await self.chaindb.coro_persist_header(header)
            start_block = chain_head.block_number
        else:
            start_block = last_peer_announcement.block_number - head_info.reorg_depth
        return start_block

    # TODO: Distribute requests among our peers, ensuring the selected peer has the info we want
    # and respecting the flow control rules.
    async def process_announcement(self, peer: LESPeer,
                                   head_info: les.HeadInfo) -> None:
        if await self.chaindb.coro_header_exists(head_info.block_hash):
            self.logger.debug(
                "Skipping processing of %s from %s as head has already been fetched",
                head_info, peer)
            return

        start_block = await self.get_sync_start_block(peer, head_info)
        while start_block < head_info.block_number:
            try:
                # We should use "start_block + 1" here, but we always re-fetch the last synced
                # block to work around https://github.com/ethereum/go-ethereum/issues/15447
                batch = await self.fetch_headers(start_block, peer)
            except TooManyTimeouts:
                raise LESAnnouncementProcessingError(
                    "Too many timeouts when fetching headers from {}".format(
                        peer))
            for header in batch:
                await self.chaindb.coro_persist_header(header)
                start_block = header.block_number
            self.logger.info("synced headers up to #%s", start_block)

    async def stop(self):
        self.logger.info("Stopping LightChain...")
        self.cancel_token.trigger()
        self.logger.debug("Waiting for all pending tasks to finish...")
        await self.wait_until_finished()
        self.logger.debug("LightChain finished")

    async def get_canonical_block_by_number(self,
                                            block_number: int) -> BaseBlock:
        """Return the block with the given number from the canonical chain.

        Raises BlockNotFound if it is not found.
        """
        try:
            block_hash = await self.chaindb.coro_lookup_block_hash(block_number)
        except KeyError:
            raise BlockNotFound(
                "No block with number {} found on local chain".format(
                    block_number))
        return await self.get_block_by_hash(block_hash)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_block_by_hash(self, block_hash: bytes) -> BaseBlock:
        peer = await self.get_best_peer()
        try:
            header = await self.chaindb.coro_get_block_header_by_hash(
                block_hash)
        except BlockNotFound:
            self.logger.debug("Fetching header %s from %s",
                              encode_hex(block_hash), peer)
            header = await peer.get_block_header_by_hash(
                block_hash, self.cancel_token)

        self.logger.debug("Fetching block %s from %s", encode_hex(block_hash),
                          peer)
        body = await peer.get_block_by_hash(block_hash, self.cancel_token)
        block_class = self.get_vm_class_for_block_number(
            header.block_number).get_block_class()
        transactions = [
            block_class.transaction_class.from_base_transaction(tx)
            for tx in body.transactions
        ]
        return block_class(
            header=header,
            transactions=transactions,
            uncles=body.uncles,
        )

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_receipts(self, block_hash: bytes) -> List[Receipt]:
        peer = await self.get_best_peer()
        self.logger.debug("Fetching %s receipts from %s",
                          encode_hex(block_hash), peer)
        return await peer.get_receipts(block_hash, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_account(self, block_hash: bytes, address: bytes) -> Account:
        peer = await self.get_best_peer()
        return await peer.get_account(block_hash, address, self.cancel_token)

    @alru_cache(maxsize=1024, cache_exceptions=False)
    async def get_contract_code(self, block_hash: bytes, key: bytes) -> bytes:
        peer = await self.get_best_peer()
        return await peer.get_contract_code(block_hash, key, self.cancel_token)
Example #12
def test_token_single():
    token = CancelToken('token')
    assert not token.triggered
    token.trigger()
    assert token.triggered
    assert token.triggered_token == token
Example #13
class ChainSyncer(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.chain.ChainSyncer")
    # We'll only sync if we are connected to at least min_peers_to_sync.
    min_peers_to_sync = 2

    def __init__(self, chaindb: AsyncChainDB, peer_pool: PeerPool) -> None:
        self.chaindb = chaindb
        self.peer_pool = peer_pool
        self.peer_pool.min_peers = self.min_peers_to_sync
        self.peer_pool.subscribe(self)
        self.cancel_token = CancelToken('ChainSyncer')
        self._running_peers = set()  # type: Set[ETHPeer]
        self._syncing = False
        self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
        self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))
        highest_td_peer = max(
            [cast(ETHPeer, peer) for peer in self.peer_pool.peers],
            key=operator.attrgetter('head_td'))
        self._sync_requests.put_nowait(highest_td_peer)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            try:
                self.handle_msg(peer, cmd, msg)
            except Exception as e:
                self.logger.error(
                    "Unexpected error when processing msg from %s: %s", peer,
                    repr(e))
                break

    async def run(self) -> None:
        while True:
            try:
                peer = await wait_with_token(self._sync_requests.get(),
                                             token=self.cancel_token)
            except OperationCancelled:
                break

            asyncio.ensure_future(self.sync(peer))

            # TODO: If we're in light mode and we've synced up to head - 1024, trigger cancel
            # token to stop and raise an exception to tell our caller it should perform a state
            # sync.

    async def sync(self, peer: ETHPeer) -> None:
        if self._syncing:
            self.logger.debug(
                "Got a NewBlock or a new peer, but already syncing so doing nothing"
            )
            return
        elif len(self._running_peers) < self.min_peers_to_sync:
            self.logger.debug(
                "Connected to less peers (%d) than the minimum (%d) required to sync, "
                "doing nothing", len(self._running_peers),
                self.min_peers_to_sync)
            return

        self._syncing = True
        try:
            await self._sync(peer)
        finally:
            self._syncing = False

    async def _sync(self, peer: ETHPeer) -> None:
        head = await self.chaindb.coro_get_canonical_head()
        head_td = await self.chaindb.coro_get_score(head.hash)
        if peer.head_td <= head_td:
            self.logger.debug(
                "Head TD (%d) announced by %s not higher than ours (%d), not syncing",
                peer.head_td, peer, head_td)
            return

        # FIXME: Fetch a batch of headers, in reverse order, starting from our current head, and
        # find the common ancestor between our chain and the peer's.
        start_at = max(0, head.block_number - eth.MAX_HEADERS_FETCH)
        self.logger.debug("Asking %s for header batch starting at %d", peer,
                          start_at)
        peer.sub_proto.send_get_block_headers(start_at,
                                              eth.MAX_HEADERS_FETCH,
                                              reverse=False)
        max_consecutive_timeouts = 3
        consecutive_timeouts = 0
        while True:
            try:
                headers = await wait_with_token(self._new_headers.get(),
                                                peer.wait_until_finished(),
                                                token=self.cancel_token,
                                                timeout=3)
            except OperationCancelled:
                break
            except TimeoutError:
                self.logger.debug("Timeout waiting for header batch from %s",
                                  peer)
                consecutive_timeouts += 1
                if consecutive_timeouts > max_consecutive_timeouts:
                    self.logger.debug(
                        "Too many consecutive timeouts waiting for header batch, aborting sync "
                        "with %s", peer)
                    break
                continue

            if peer.is_finished():
                self.logger.debug("%s disconnected, stopping sync", peer)
                break

            consecutive_timeouts = 0
            if headers[-1].block_number <= start_at:
                self.logger.debug(
                    "Ignoring headers from %d to %d as they've been processed already",
                    headers[0].block_number, headers[-1].block_number)
                continue

            # TODO: Process headers for consistency.
            # TODO: Queue body/receipt downloads.
            for header in headers:
                await self.chaindb.coro_persist_header_to_db(header)
                start_at = header.block_number

            self.logger.debug("Asking %s for header batch starting at %d",
                              peer, start_at)
            # TODO: Instead of requesting sequential batches from a single peer, request a header
            # skeleton and make concurrent requests, using as many peers as possible, to fill the
            # skeleton.
            peer.sub_proto.send_get_block_headers(start_at,
                                                  eth.MAX_HEADERS_FETCH,
                                                  reverse=False)

    async def wait_until_finished(self) -> None:
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        self.logger.info("Waiting for %d running peers to finish",
                         len(self._running_peers))
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            await asyncio.sleep(0.1)
        else:
            self.logger.info(
                "Waited too long for peers to finish, exiting anyway")

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        await self.wait_until_finished()

    def handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                   msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd, eth.BlockHeaders):
            msg = cast(List[BlockHeader], msg)
            self.logger.debug("Got BlockHeaders from %d to %d",
                              msg[0].block_number, msg[-1].block_number)
            self._new_headers.put_nowait(msg)
        elif isinstance(cmd, eth.BlockBodies):
            # TODO: Queue msg for processing by body downloader.
            pass
        elif isinstance(cmd, eth.Receipts):
            # TODO: Queue msg for processing by receipt downloader.
            pass
        elif isinstance(cmd, eth.NewBlock):
            msg = cast(Dict[str, Any], msg)
            header = msg['block'][0]
            actual_head = header.parent_hash
            actual_td = msg['total_difficulty'] - header.difficulty
            if actual_td > peer.head_td:
                peer.head_hash = actual_head
                peer.head_td = actual_td
                self._sync_requests.put_nowait(peer)
        elif isinstance(cmd, eth.Transactions):
            # TODO: Figure out what to do with those.
            pass
        elif isinstance(cmd, eth.NewBlockHashes):
            # TODO: Figure out what to do with those.
            pass
        else:
            # TODO: There are other msg types we'll want to handle here, but for now just log them
            # as a warning so we don't forget about it.
            self.logger.warning("Got unexpected msg: %s (%s)", cmd, msg)
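The header loop above hinges on one pattern: wait_with_token() races an awaitable against the cancel token (and an optional timeout), raising OperationCancelled or TimeoutError instead of hanging. A condensed sketch of that pattern; next_headers, queue and token are stand-ins, not names from the codebase:

async def next_headers(queue: asyncio.Queue, token: CancelToken):
    try:
        # Completes with the queue item, or raises if the token fires first
        # or nothing arrives within 3 seconds.
        return await wait_with_token(queue.get(), token=token, timeout=3)
    except OperationCancelled:
        return None  # our token (or one chained to it) was triggered
    except TimeoutError:
        return None  # no header batch arrived in time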
Example No. 16
class PeerPool:
    """PeerPool attempts to keep connections to at least min_peers on the given network."""
    logger = logging.getLogger("p2p.peer.PeerPool")
    min_peers = 2
    _connect_loop_sleep = 2

    def __init__(
        self,
        peer_class: Type[BasePeer],
        chaindb: ChainDB,
        network_id: int,
        privkey: datatypes.PrivateKey,
    ) -> None:
        self.peer_class = peer_class
        self.chaindb = chaindb
        self.network_id = network_id
        self.privkey = privkey
        self.connected_nodes = {}  # type: Dict[Node, BasePeer]
        self.cancel_token = CancelToken('PeerPool')
        self._subscribers = []  # type: List[PeerPoolSubscriber]

    def subscribe(self, subscriber: PeerPoolSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)

    def unsubscribe(self, subscriber: PeerPoolSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)

    async def get_nodes_to_connect(self) -> List[Node]:
        # TODO: This should use the Discovery service to lookup nodes to connect to, but our
        # current implementation only supports v4 and with that it takes an insane amount of time
        # to find any LES nodes with the same network ID as us, so for now we hard-code some nodes
        # that seem to have a good uptime.
        from evm.chains.ropsten import RopstenChain
        from evm.chains.mainnet import MainnetChain
        if self.network_id == MainnetChain.network_id:
            return [
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "1118980bf48b0a3640bdba04e0fe78b1add18e1cd99bf22d53daac1fd9972ad650df52176e7c7d89d1114cfef2bc23a2959aa54998a46afcf7d91809f0855082"
                        )),  # noqa: E501
                    Address("52.74.57.123", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "78de8a0916848093c73790ead81d1928bec737d565119932b98c6b100d944b7a95e94f847f689fc723399d2e31129d182f7ef3863f2b4c820abbf3ab2722344d"
                        )),  # noqa: E501
                    Address("191.235.84.50", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "ddd81193df80128880232fc1deb45f72746019839589eeb642d3d44efbb8b2dda2c1a46a348349964a6066f8afb016eb2a8c0f3c66f32fadf4370a236a4b5286"
                        )),  # noqa: E501
                    Address("52.231.202.145", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "3f1d12044546b76342d59d4a05532c14b85aa669704bfe1f864fe079415aa2c02d743e03218e57a33fb94523adb54032871a6c51b2cc5514cb7c7e35b3ed0a99"
                        )),  # noqa: E501
                    Address("13.93.211.84", 30303, 30303)),
            ]
        elif self.network_id == RopstenChain.network_id:
            return [
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "88c2b24429a6f7683fbfd06874ae3f1e7c8b4a5ffb846e77c705ba02e2543789d66fc032b6606a8d8888eb6239a2abe5897ce83f78dcdcfcb027d6ea69aa6fe9"
                        )),  # noqa: E501
                    Address("163.172.157.61", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "a1ef9ba5550d5fac27f7cbd4e8d20a643ad75596f307c91cd6e7f85b548b8a6bf215cca436d6ee436d6135f9fe51398f8dd4c0bd6c6a0c332ccb41880f33ec12"
                        )),  # noqa: E501
                    Address("51.15.218.125", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "e80276aabb7682a4a659f4341c1199de79d91a2e500a6ee9bed16ed4ce927ba8d32ba5dea357739ffdf2c5bcc848d3064bb6f149f0b4249c1f7e53f8bf02bfc8"
                        )),  # noqa: E501
                    Address("51.15.39.57", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "584c0db89b00719e9e7b1b5c32a4a8942f379f4d5d66bb69f9c7fa97fa42f64974e7b057b35eb5a63fd7973af063f9a1d32d8c60dbb4854c64cb8ab385470258"
                        )),  # noqa: E501
                    Address("51.15.35.2", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "d40871fc3e11b2649700978e06acd68a24af54e603d4333faecb70926ca7df93baa0b7bf4e927fcad9a7c1c07f9b325b22f6d1730e728314d0e4e6523e5cebc2"
                        )),  # noqa: E501
                    Address("51.15.132.235", 30303, 30303)),
                Node(
                    keys.PublicKey(
                        decode_hex(
                            "482484b9198530ee2e00db89791823244ca41dcd372242e2e1297dd06f6d8dd357603960c5ad9cc8dc15fcdf0e4edd06b7ad7db590e67a0b54f798c26581ebd7"
                        )),  # noqa: E501
                    Address("51.15.75.138", 30303, 30303)),
            ]
        else:
            raise ValueError("Unknown network_id: {}".format(self.network_id))

    async def run(self):
        self.logger.info("Running PeerPool...")
        while not self.cancel_token.triggered:
            try:
                await self.maybe_connect_to_more_peers()
            except Exception:
                # Most unexpected errors should be transient, so we log and restart from scratch.
                self.logger.error("Unexpected error (%s), restarting",
                                  traceback.format_exc())
                await self.stop_all_peers()
            # Wait self._connect_loop_sleep seconds, unless we're asked to stop.
            await asyncio.wait([self.cancel_token.wait()],
                               timeout=self._connect_loop_sleep)

    async def stop_all_peers(self):
        self.logger.info("Stopping all peers ...")
        await asyncio.gather(
            *[peer.stop() for peer in self.connected_nodes.values()])

    async def stop(self):
        self.cancel_token.trigger()
        await self.stop_all_peers()

    async def connect(self, remote: Node) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Returns None if the remote is unreachable, times out or is useless.
        """
        if remote in self.connected_nodes:
            self.logger.debug("Skipping %s; already connected to it", remote)
            return None
        expected_exceptions = (UnreachablePeer, asyncio.TimeoutError,
                               PeerConnectionLost, HandshakeFailure)
        try:
            self.logger.info("Connecting to %s...", remote)
            # TODO: Use asyncio.wait() and our cancel_token here to cancel in case the token is
            # triggered.
            peer = await asyncio.wait_for(
                handshake(remote, self.privkey, self.peer_class, self.chaindb,
                          self.network_id), HANDSHAKE_TIMEOUT)
            return peer
        except expected_exceptions as e:
            self.logger.info("Could not complete handshake with %s: %s",
                             remote, repr(e))
        except Exception:
            self.logger.warning(
                "Unexpected error during auth/p2p handshake with %s: %s",
                remote, traceback.format_exc())
        return None

    async def maybe_connect_to_more_peers(self):
        """Connect to more peers if we're not yet connected to at least self.min_peers."""
        if len(self.connected_nodes) >= self.min_peers:
            self.logger.debug("Already connected to %s peers: %s; sleeping",
                              len(self.connected_nodes),
                              [remote for remote in self.connected_nodes])
            return

        for node in await self.get_nodes_to_connect():
            # TODO: Consider changing connect() to raise an exception instead of returning None,
            # as discussed in
            # https://github.com/pipermerriam/py-evm/pull/139#discussion_r152067425
            peer = await self.connect(node)
            if peer is not None:
                self.logger.info("Successfully connected to %s", peer)
                asyncio.ensure_future(
                    peer.run(finished_callback=self._peer_finished))
                self.connected_nodes[peer.remote] = peer
                for subscriber in self._subscribers:
                    subscriber.register_peer(peer)

    def _peer_finished(self, peer: BasePeer) -> None:
        """Remove the given peer from our list of connected nodes.
        This is passed as a callback to be called when a peer finishes.
        """
        if peer.remote in self.connected_nodes:
            self.connected_nodes.pop(peer.remote)

    @property
    def peers(self) -> List[BasePeer]:
        return list(self.connected_nodes.values())
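Subscribers plug into the pool through a single hook: subscribe() immediately replays register_peer() for every already-connected peer, and maybe_connect_to_more_peers() calls it again for each new connection. A hypothetical minimal subscriber (LoggingSubscriber is not a name from the codebase):

class LoggingSubscriber(PeerPoolSubscriber):
    def register_peer(self, peer: BasePeer) -> None:
        # Called once per peer: at subscribe() time for existing peers and
        # from maybe_connect_to_more_peers() for new ones.
        print("peer joined the pool:", peer)

peer_pool.subscribe(LoggingSubscriber())  # peer_pool: an existing PeerPool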
Example No. 17
async def test_wait_with_token_operation_cancelled(event_loop):
    token = CancelToken('token')
    token.trigger()
    with pytest.raises(OperationCancelled):
        await wait_with_token(asyncio.sleep(0.02), token)
    await assert_only_current_task_not_done()
Example No. 18
class Server:
    """Server listening for incoming connections"""
    logger = logging.getLogger("p2p.server.Server")
    _server = None

    def __init__(
        self,
        privkey: datatypes.PrivateKey,
        server_address: Address,
        chaindb: AsyncChainDB,
        bootstrap_nodes: List[str],
        network_id: int,
        min_peers: int = DEFAULT_MIN_PEERS,
        peer_class: Type[BasePeer] = ETHPeer,
    ) -> None:
        self.cancel_token = CancelToken('Server')
        self.chaindb = chaindb
        self.privkey = privkey
        self.server_address = server_address
        self.network_id = network_id
        self.peer_class = peer_class
        # TODO: bootstrap_nodes should be looked up by network_id.
        discovery = DiscoveryProtocol(self.privkey,
                                      self.server_address,
                                      bootstrap_nodes=bootstrap_nodes)
        self.peer_pool = PeerPool(
            peer_class,
            self.chaindb,
            self.network_id,
            self.privkey,
            discovery,
            min_peers=min_peers,
        )

    async def start(self) -> None:
        self._server = await asyncio.start_server(
            self.receive_handshake,
            host=self.server_address.ip,
            port=self.server_address.tcp_port,
        )

    async def run(self) -> None:
        await self.start()
        self.logger.info("Running server...")
        self.logger.info(
            "enode://%s@%s:%s",
            self.privkey.public_key.to_hex()[2:],
            self.server_address.ip,
            self.server_address.tcp_port,
        )
        await self.cancel_token.wait()
        await self.stop()

    async def stop(self) -> None:
        self.logger.info("Closing server...")
        self.cancel_token.trigger()
        self._server.close()
        await self._server.wait_closed()
        await self.peer_pool.stop()

    async def receive_handshake(self, reader: asyncio.StreamReader,
                                writer: asyncio.StreamWriter) -> None:
        await wait_with_token(
            self._receive_handshake(reader, writer),
            token=self.cancel_token,
            timeout=HANDSHAKE_TIMEOUT,
        )

    async def _receive_handshake(self, reader: asyncio.StreamReader,
                                 writer: asyncio.StreamWriter) -> None:
        self.logger.debug("Receiving handshake...")
        # Read the fixed-size (non-EIP-8) auth_init msg first.
        msg = await wait_with_token(
            reader.read(ENCRYPTED_AUTH_MSG_LEN),
            token=self.cancel_token,
        )

        # Use decode_authentication(auth_init_message) on auth init msg
        try:
            ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                msg, self.privkey)
        # Try to decode as EIP8
        except DecryptionError:
            msg_size = big_endian_to_int(msg[:2])
            remaining_bytes = msg_size - ENCRYPTED_AUTH_MSG_LEN + 2
            msg += await wait_with_token(
                reader.read(remaining_bytes),
                token=self.cancel_token,
            )
            ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                msg, self.privkey)

        # Get remote's address: IPv4 or IPv6
        ip, port, *_ = writer.get_extra_info("peername")
        remote_address = Address(ip, port)

        # Create `HandshakeResponder(remote: kademlia.Node, privkey: datatypes.PrivateKey)` instance
        initiator_remote = Node(initiator_pubkey, remote_address)
        responder = HandshakeResponder(initiator_remote, self.privkey)

        # Call `HandshakeResponder.create_auth_ack_message(nonce: bytes)` to create the reply
        responder_nonce = secrets.token_bytes(HASH_LEN)
        auth_ack_msg = responder.create_auth_ack_message(nonce=responder_nonce)
        auth_ack_ciphertext = responder.encrypt_auth_ack_message(auth_ack_msg)

        # Use the `writer` to send the reply to the remote
        writer.write(auth_ack_ciphertext)
        await writer.drain()

        # Call `HandshakeResponder.derive_shared_secrets()` and use return values to create `Peer`
        aes_secret, mac_secret, egress_mac, ingress_mac = responder.derive_secrets(
            initiator_nonce=initiator_nonce,
            responder_nonce=responder_nonce,
            remote_ephemeral_pubkey=ephem_pubkey,
            auth_init_ciphertext=msg,
            auth_ack_ciphertext=auth_ack_ciphertext)

        # Create and register peer in peer_pool
        peer = self.peer_class(remote=initiator_remote,
                               privkey=self.privkey,
                               reader=reader,
                               writer=writer,
                               aes_secret=aes_secret,
                               mac_secret=mac_secret,
                               egress_mac=egress_mac,
                               ingress_mac=ingress_mac,
                               chaindb=self.chaindb,
                               network_id=self.network_id)

        await self.do_p2p_handshake(peer)

    async def do_p2p_handshake(self, peer: BasePeer) -> None:
        try:
            # P2P Handshake.
            await peer.do_p2p_handshake()
        except (HandshakeFailure, asyncio.TimeoutError) as e:
            self.logger.debug('Unable to finish P2P handshake: %s', str(e))
        else:
            # Run peer and add peer.
            self.peer_pool.start_peer(peer)
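Note that run() blocks on the server's own cancel token and then calls stop() itself, so external code only has to trigger the token. A minimal start/shutdown sketch, inside a coroutine, with all constructor arguments assumed to already exist:

server = Server(privkey, server_address, chaindb, bootstrap_nodes, network_id)
task = asyncio.ensure_future(server.run())
# ... later, to shut down:
server.cancel_token.trigger()  # run() unblocks and awaits self.stop()
await task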
Example No. 20
class StateDownloader(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.state.StateDownloader")
    _pending_nodes = {}  # type: Dict[Any, float]
    _total_processed_nodes = 0
    _report_interval = 10  # Number of seconds between progress reports.
    _reply_timeout = 20  # seconds
    _start_time = None  # type: float
    _total_timeouts = 0

    def __init__(self,
                 state_db: BaseDB,
                 root_hash: bytes,
                 peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.root_hash = root_hash
        self.scheduler = StateSync(root_hash, state_db)
        self._running_peers = set()  # type: Set[ETHPeer]
        self._peers_with_pending_requests = {}  # type: Dict[ETHPeer, float]
        self.cancel_token = CancelToken('StateDownloader')
        if token is not None:
            self.cancel_token = self.cancel_token.chain(token)

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))

    @property
    def idle_peers(self) -> List[ETHPeer]:
        peers = set([cast(ETHPeer, peer) for peer in self.peer_pool.peers])
        return list(peers.difference(self._peers_with_pending_requests))

    async def get_idle_peer(self) -> ETHPeer:
        while not self.idle_peers:
            self.logger.debug("Waiting for an idle peer...")
            await wait_with_token(asyncio.sleep(0.02), token=self.cancel_token)
        return secrets.choice(self.idle_peers)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            # Run self._handle_msg() with ensure_future() instead of awaiting for it so that we
            # can keep consuming msgs while _handle_msg() performs cpu-intensive tasks in separate
            # processes.
            asyncio.ensure_future(self._handle_msg(peer, cmd, msg))

    async def _handle_msg(
            self, peer: ETHPeer, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        loop = asyncio.get_event_loop()
        if isinstance(cmd, eth.NodeData):
            self.logger.debug("Got %d NodeData entries from %s", len(msg), peer)

            # Check before we remove because sometimes a reply may come after our timeout and in
            # that case we won't be expecting it anymore.
            if peer in self._peers_with_pending_requests:
                self._peers_with_pending_requests.pop(peer)

            node_keys = await loop.run_in_executor(None, list, map(keccak, msg))
            for node_key, node in zip(node_keys, msg):
                self._total_processed_nodes += 1
                try:
                    self.scheduler.process([(node_key, node)])
                except SyncRequestAlreadyProcessed:
                    # This means we received a node more than once, which can happen when we
                    # retry after a timeout.
                    pass
                # A node may be received more than once, so pop() with a default value.
                self._pending_nodes.pop(node_key, None)
        else:
            # We ignore everything that is not a NodeData when doing a StateSync.
            self.logger.debug("Ignoring %s msg while doing a StateSync", cmd)

    async def stop(self):
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        while self._running_peers:
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)

    async def request_nodes(self, node_keys: List[bytes]) -> None:
        batches = list(partition_all(eth.MAX_STATE_FETCH, node_keys))
        for batch in batches:
            peer = await self.get_idle_peer()
            now = time.time()
            for node_key in batch:
                self._pending_nodes[node_key] = now
            self.logger.debug("Requesting %d trie nodes to %s", len(batch), peer)
            peer.sub_proto.send_get_node_data(batch)
            self._peers_with_pending_requests[peer] = now

    async def _periodically_retry_timedout(self):
        while True:
            now = time.time()
            # First, update our list of peers with pending requests by removing those for which a
            # request timed out. This loop mutates the dict, so we iterate on a copy of it.
            for peer, last_req_time in list(self._peers_with_pending_requests.items()):
                if now - last_req_time > self._reply_timeout:
                    self._peers_with_pending_requests.pop(peer)

            # Now re-send requests for nodes that timed out.
            oldest_request_time = now
            timed_out = []
            for node_key, req_time in self._pending_nodes.items():
                if now - req_time > self._reply_timeout:
                    timed_out.append(node_key)
                elif req_time < oldest_request_time:
                    oldest_request_time = req_time
            if timed_out:
                self.logger.debug("Re-requesting %d trie nodes", len(timed_out))
                self._total_timeouts += len(timed_out)
                try:
                    await self.request_nodes(timed_out)
                except OperationCancelled:
                    break

            # Finally, sleep until the time our oldest request is scheduled to timeout.
            now = time.time()
            sleep_duration = (oldest_request_time + self._reply_timeout) - now
            try:
                await wait_with_token(asyncio.sleep(sleep_duration), token=self.cancel_token)
            except OperationCancelled:
                break

    async def run(self):
        """Fetch all trie nodes starting from self.root_hash, and store them in self.db.

        Raises OperationCancelled if we're interrupted before that is completed.
        """
        self._start_time = time.time()
        self.logger.info("Starting state sync for root hash %s", encode_hex(self.root_hash))
        asyncio.ensure_future(self._periodically_report_progress())
        asyncio.ensure_future(self._periodically_retry_timedout())
        while self.scheduler.has_pending_requests:
            # This ensures we yield control and give _handle_msg() a chance to process any nodes
            # we may have received already, also ensuring we exit when our cancel token is
            # triggered.
            await wait_with_token(asyncio.sleep(0), token=self.cancel_token)

            requests = self.scheduler.next_batch(eth.MAX_STATE_FETCH)
            if not requests:
                # Although we frequently yield control above, to let our msg handler process
                # received nodes (scheduling new requests), there may be cases when the
                # pending nodes take a while to arrive thus causing the scheduler to run out
                # of new requests for a while.
                self.logger.info("Scheduler queue is empty, sleeping a bit")
                await wait_with_token(asyncio.sleep(0.5), token=self.cancel_token)
                continue

            await self.request_nodes([request.node_key for request in requests])

        self.logger.info("Finished state sync with root hash %s", encode_hex(self.root_hash))

    async def _periodically_report_progress(self):
        while True:
            now = time.time()
            self.logger.info("====== State sync progress ========")
            self.logger.info("Nodes processed: %d", self._total_processed_nodes)
            self.logger.info("Nodes processed per second (average): %d",
                             self._total_processed_nodes / (now - self._start_time))
            self.logger.info("Nodes committed to DB: %d", self.scheduler.committed_nodes)
            self.logger.info(
                "Nodes requested but not received yet: %d", len(self._pending_nodes))
            self.logger.info(
                "Nodes scheduled but not requested yet: %d", len(self.scheduler.requests))
            self.logger.info("Total nodes timed out: %d", self._total_timeouts)
            try:
                await wait_with_token(asyncio.sleep(self._report_interval), token=self.cancel_token)
            except OperationCancelled:
                break
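Because the constructor chains an optional parent token, a caller can cancel the whole download externally. A minimal wiring sketch, assuming state_db, root_hash and peer_pool already exist:

parent_token = CancelToken('app')
downloader = StateDownloader(state_db, root_hash, peer_pool, token=parent_token)
try:
    # run() raises OperationCancelled if parent_token (or the downloader's
    # own token) is triggered before the state sync completes.
    await downloader.run()
finally:
    await downloader.stop()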
Example No. 21
class PeerPool:
    """PeerPool attempts to keep connections to at least min_peers on the given network."""
    logger = logging.getLogger("p2p.peer.PeerPool")
    min_peers = 2
    _connect_loop_sleep = 2

    def __init__(self,
                 peer_class: Type[BasePeer],
                 chaindb: AsyncChainDB,
                 network_id: int,
                 privkey: datatypes.PrivateKey,
                 ) -> None:
        self.peer_class = peer_class
        self.chaindb = chaindb
        self.network_id = network_id
        self.privkey = privkey
        self.connected_nodes = {}  # type: Dict[Node, BasePeer]
        self.cancel_token = CancelToken('PeerPool')
        self._subscribers = []  # type: List[PeerPoolSubscriber]

    def subscribe(self, subscriber: PeerPoolSubscriber) -> None:
        self._subscribers.append(subscriber)
        for peer in self.connected_nodes.values():
            subscriber.register_peer(peer)

    def unsubscribe(self, subscriber: PeerPoolSubscriber) -> None:
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)

    async def get_nodes_to_connect(self) -> List[Node]:
        # TODO: This should use the Discovery service to lookup nodes to connect to, but our
        # current implementation only supports v4 and with that it takes an insane amount of time
        # to find any LES nodes with the same network ID as us, so for now we hard-code some nodes
        # that seem to have a good uptime.
        from evm.chains.ropsten import RopstenChain
        from evm.chains.mainnet import MainnetChain
        if self.network_id == MainnetChain.network_id:
            return [
                Node(
                    keys.PublicKey(decode_hex("1118980bf48b0a3640bdba04e0fe78b1add18e1cd99bf22d53daac1fd9972ad650df52176e7c7d89d1114cfef2bc23a2959aa54998a46afcf7d91809f0855082")),  # noqa: E501
                    Address("52.74.57.123", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("78de8a0916848093c73790ead81d1928bec737d565119932b98c6b100d944b7a95e94f847f689fc723399d2e31129d182f7ef3863f2b4c820abbf3ab2722344d")),  # noqa: E501
                    Address("191.235.84.50", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("ddd81193df80128880232fc1deb45f72746019839589eeb642d3d44efbb8b2dda2c1a46a348349964a6066f8afb016eb2a8c0f3c66f32fadf4370a236a4b5286")),  # noqa: E501
                    Address("52.231.202.145", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("3f1d12044546b76342d59d4a05532c14b85aa669704bfe1f864fe079415aa2c02d743e03218e57a33fb94523adb54032871a6c51b2cc5514cb7c7e35b3ed0a99")),  # noqa: E501
                    Address("13.93.211.84", 30303, 30303)),
            ]
        elif self.network_id == RopstenChain.network_id:
            return [
                Node(
                    keys.PublicKey(decode_hex("88c2b24429a6f7683fbfd06874ae3f1e7c8b4a5ffb846e77c705ba02e2543789d66fc032b6606a8d8888eb6239a2abe5897ce83f78dcdcfcb027d6ea69aa6fe9")),  # noqa: E501
                    Address("163.172.157.61", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("a1ef9ba5550d5fac27f7cbd4e8d20a643ad75596f307c91cd6e7f85b548b8a6bf215cca436d6ee436d6135f9fe51398f8dd4c0bd6c6a0c332ccb41880f33ec12")),  # noqa: E501
                    Address("51.15.218.125", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("e80276aabb7682a4a659f4341c1199de79d91a2e500a6ee9bed16ed4ce927ba8d32ba5dea357739ffdf2c5bcc848d3064bb6f149f0b4249c1f7e53f8bf02bfc8")),  # noqa: E501
                    Address("51.15.39.57", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("584c0db89b00719e9e7b1b5c32a4a8942f379f4d5d66bb69f9c7fa97fa42f64974e7b057b35eb5a63fd7973af063f9a1d32d8c60dbb4854c64cb8ab385470258")),  # noqa: E501
                    Address("51.15.35.2", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("d40871fc3e11b2649700978e06acd68a24af54e603d4333faecb70926ca7df93baa0b7bf4e927fcad9a7c1c07f9b325b22f6d1730e728314d0e4e6523e5cebc2")),  # noqa: E501
                    Address("51.15.132.235", 30303, 30303)),
                Node(
                    keys.PublicKey(decode_hex("482484b9198530ee2e00db89791823244ca41dcd372242e2e1297dd06f6d8dd357603960c5ad9cc8dc15fcdf0e4edd06b7ad7db590e67a0b54f798c26581ebd7")),  # noqa: E501
                    Address("51.15.75.138", 30303, 30303)),
            ]
        else:
            raise ValueError("Unknown network_id: {}".format(self.network_id))

    async def run(self):
        self.logger.info("Running PeerPool...")
        while not self.cancel_token.triggered:
            try:
                await self.maybe_connect_to_more_peers()
            except Exception:
                # Most unexpected errors should be transient, so we log and restart from scratch.
                self.logger.error("Unexpected error (%s), restarting", traceback.format_exc())
                await self.stop_all_peers()
            # Wait self._connect_loop_sleep seconds, unless we're asked to stop.
            await asyncio.wait([self.cancel_token.wait()], timeout=self._connect_loop_sleep)

    async def stop_all_peers(self):
        self.logger.info("Stopping all peers ...")
        await asyncio.gather(
            *[peer.stop() for peer in self.connected_nodes.values()])

    async def stop(self):
        self.cancel_token.trigger()
        await self.stop_all_peers()

    async def connect(self, remote: Node) -> BasePeer:
        """
        Connect to the given remote and return a Peer instance when successful.
        Returns None if the remote is unreachable, times out or is useless.
        """
        if remote in self.connected_nodes:
            self.logger.debug("Skipping %s; already connected to it", remote)
            return None
        expected_exceptions = (
            UnreachablePeer, asyncio.TimeoutError, PeerConnectionLost, HandshakeFailure)
        try:
            self.logger.info("Connecting to %s...", remote)
            # TODO: Use asyncio.wait() and our cancel_token here to cancel in case the token is
            # triggered.
            peer = await asyncio.wait_for(
                handshake(remote, self.privkey, self.peer_class, self.chaindb, self.network_id),
                HANDSHAKE_TIMEOUT)
            return peer
        except expected_exceptions as e:
            self.logger.info("Could not complete handshake with %s: %s", remote, repr(e))
        except Exception:
            self.logger.warning("Unexpected error during auth/p2p handshake with %s: %s",
                                remote, traceback.format_exc())
        return None

    async def maybe_connect_to_more_peers(self):
        """Connect to more peers if we're not yet connected to at least self.min_peers."""
        if len(self.connected_nodes) >= self.min_peers:
            self.logger.debug(
                "Already connected to %s peers: %s; sleeping",
                len(self.connected_nodes),
                [remote for remote in self.connected_nodes])
            return

        for node in await self.get_nodes_to_connect():
            # TODO: Consider changing connect() to raise an exception instead of returning None,
            # as discussed in
            # https://github.com/pipermerriam/py-evm/pull/139#discussion_r152067425
            peer = await self.connect(node)
            if peer is not None:
                self.logger.info("Successfully connected to %s", peer)
                asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
                self.connected_nodes[peer.remote] = peer
                for subscriber in self._subscribers:
                    subscriber.register_peer(peer)

    def _peer_finished(self, peer: BasePeer) -> None:
        """Remove the given peer from our list of connected nodes.
        This is passed as a callback to be called when a peer finishes.
        """
        if peer.remote in self.connected_nodes:
            self.connected_nodes.pop(peer.remote)

    @property
    def peers(self) -> List[BasePeer]:
        return list(self.connected_nodes.values())
Example No. 22
class BasePeer:
    logger = logging.getLogger("p2p.peer.Peer")
    conn_idle_timeout = CONN_IDLE_TIMEOUT
    reply_timeout = REPLY_TIMEOUT
    # Must be defined in subclasses. All items here must be Protocol classes representing
    # different versions of the same P2P sub-protocol (e.g. ETH, LES, etc).
    _supported_sub_protocols = []  # type: List[Type[protocol.Protocol]]
    # FIXME: Must be configurable.
    listen_port = 30303
    # Will be set upon the successful completion of a P2P handshake.
    sub_proto = None  # type: protocol.Protocol

    def __init__(self,
                 remote: Node,
                 privkey: datatypes.PrivateKey,
                 reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 aes_secret: bytes,
                 mac_secret: bytes,
                 egress_mac: PreImage,
                 ingress_mac: PreImage,
                 chaindb: AsyncChainDB,
                 network_id: int,
                 ) -> None:
        self._finished = asyncio.Event()
        self.remote = remote
        self.privkey = privkey
        self.reader = reader
        self.writer = writer
        self.base_protocol = P2PProtocol(self)
        self.chaindb = chaindb
        self.network_id = network_id
        self.sub_proto_msg_queue = asyncio.Queue()  # type: asyncio.Queue[Tuple[protocol.Command, protocol._DecodedMsgType]]  # noqa: E501
        self.cancel_token = CancelToken('Peer')

        self.egress_mac = egress_mac
        self.ingress_mac = ingress_mac
        # FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
        iv = b"\x00" * 16
        aes_cipher = Cipher(algorithms.AES(aes_secret), modes.CTR(iv), default_backend())
        self.aes_enc = aes_cipher.encryptor()
        self.aes_dec = aes_cipher.decryptor()
        mac_cipher = Cipher(algorithms.AES(mac_secret), modes.ECB(), default_backend())
        self.mac_enc = mac_cipher.encryptor().update

    async def send_sub_proto_handshake(self):
        raise NotImplementedError()

    async def process_sub_proto_handshake(
            self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        raise NotImplementedError()

    async def do_sub_proto_handshake(self):
        """Perform the handshake for the sub-protocol agreed with the remote peer.

        Raises HandshakeFailure if the handshake is not successful.
        """
        await self.send_sub_proto_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the sub-proto handshake.
            raise HandshakeFailure(
                "{} disconnected before completing sub-proto handshake: {}".format(
                    self, msg['reason_name']))
        await self.process_sub_proto_handshake(cmd, msg)
        self.logger.debug("Finished %s handshake with %s", self.sub_proto, self.remote)

    async def do_p2p_handshake(self):
        """Perform the handshake for the P2P base protocol.

        Raises HandshakeFailure if the handshake is not successful.
        """
        self.base_protocol.send_handshake()
        cmd, msg = await self.read_msg()
        if isinstance(cmd, Disconnect):
            # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
            raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
                self, msg['reason_name']))
        self.process_p2p_handshake(cmd, msg)

    async def read_sub_proto_msg(
            self, cancel_token: CancelToken) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        """Read the next sub-protocol message from the queue.

        Raises OperationCancelled if the peer has been disconnected.
        """
        combined_token = self.cancel_token.chain(cancel_token)
        return await wait_with_token(self.sub_proto_msg_queue.get(), combined_token)

    @property
    async def genesis(self) -> BlockHeader:
        genesis_hash = await self.chaindb.coro_lookup_block_hash(GENESIS_BLOCK_NUMBER)
        return await self.chaindb.coro_get_block_header_by_hash(genesis_hash)

    @property
    async def _local_chain_info(self) -> 'ChainInfo':
        genesis = await self.genesis
        head = await self.chaindb.coro_get_canonical_head()
        total_difficulty = await self.chaindb.coro_get_score(head.hash)
        return ChainInfo(
            block_number=head.block_number,
            block_hash=head.hash,
            total_difficulty=total_difficulty,
            genesis_hash=genesis.hash,
        )

    @property
    def capabilities(self) -> List[Tuple[bytes, int]]:
        return [(klass.name, klass.version) for klass in self._supported_sub_protocols]

    def get_protocol_command_for(self, msg: bytes) -> protocol.Command:
        """Return the Command corresponding to the cmd_id encoded in the given msg."""
        cmd_id = get_devp2p_cmd_id(msg)
        self.logger.debug("Got msg with cmd_id: %s", cmd_id)
        if cmd_id < self.base_protocol.cmd_length:
            proto = self.base_protocol
        elif cmd_id < self.sub_proto.cmd_id_offset + self.sub_proto.cmd_length:
            proto = self.sub_proto  # type: ignore
        else:
            raise UnknownProtocolCommand("No protocol found for cmd_id {}".format(cmd_id))
        return proto.cmd_by_id[cmd_id]

    async def read(self, n: int) -> bytes:
        self.logger.debug("Waiting for %s bytes from %s", n, self.remote)
        try:
            return await wait_with_token(
                self.reader.readexactly(n), self.cancel_token, timeout=self.conn_idle_timeout)
        except (asyncio.IncompleteReadError, ConnectionResetError):
            raise PeerConnectionLost("EOF reading from stream")

    async def run(self, finished_callback: Optional[Callable[['BasePeer'], None]] = None) -> None:
        try:
            await self.read_loop()
        except OperationCancelled as e:
            self.logger.debug("Peer finished: %s", e)
        except Exception:
            self.logger.error(
                "Unexpected error when handling remote msg: %s", traceback.format_exc())
        finally:
            self._finished.set()
            if finished_callback is not None:
                finished_callback(self)

    def close(self):
        """Close this peer's reader/writer streams.

        This will cause the peer to stop in case it is running.

        If the streams have already been closed, do nothing.
        """
        if self.reader.at_eof():
            return
        self.reader.feed_eof()
        self.writer.close()

    async def stop(self):
        """Disconnect from the remote and flag this peer as finished.

        If the peer is already flagged as finished, do nothing.
        """
        if self._finished.is_set():
            return
        self.cancel_token.trigger()
        self.close()
        await self._finished.wait()
        self.logger.debug("Stopped %s", self)

    async def read_loop(self):
        while True:
            try:
                cmd, msg = await self.read_msg()
            except (PeerConnectionLost, asyncio.TimeoutError) as e:
                self.logger.info(
                    "%s stopped responding (%s), disconnecting", self.remote, repr(e))
                return

            self.process_msg(cmd, msg)

    async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
        header_data = await self.read(HEADER_LEN + MAC_LEN)
        header = self.decrypt_header(header_data)
        frame_size = self.get_frame_size(header)
        # The frame_size specified in the header does not include the padding to a 16-byte
        # boundary, so we round up here to ensure we read all of the frame's data.
        read_size = roundup_16(frame_size)
        frame_data = await self.read(read_size + MAC_LEN)
        msg = self.decrypt_body(frame_data, frame_size)
        cmd = self.get_protocol_command_for(msg)
        decoded_msg = cmd.decode(msg)
        self.logger.debug("Successfully decoded %s msg: %s", cmd, decoded_msg)
        return cmd, decoded_msg

    def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        """Handle the base protocol (P2P) messages."""
        if isinstance(cmd, Disconnect):
            msg = cast(Dict[str, Any], msg)
            self.logger.debug(
                "%s disconnected; reason given: %s", self, msg['reason_name'])
            self.close()
        elif isinstance(cmd, Ping):
            self.base_protocol.send_pong()
        elif isinstance(cmd, Pong):
            # Currently we don't do anything when we get a pong, but eventually we should
            # update the last time we heard from a peer in our DB (which doesn't exist yet).
            pass
        else:
            raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))

    def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        self.sub_proto_msg_queue.put_nowait((cmd, msg))

    def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd.proto, P2PProtocol):
            self.handle_p2p_msg(cmd, msg)
        else:
            self.handle_sub_proto_msg(cmd, msg)

    def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
        msg = cast(Dict[str, Any], msg)
        if not isinstance(cmd, Hello):
            self.disconnect(DisconnectReason.other)
            raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
        remote_capabilities = msg['capabilities']
        self.sub_proto = self.select_sub_protocol(remote_capabilities)
        if self.sub_proto is None:
            self.disconnect(DisconnectReason.useless_peer)
            raise HandshakeFailure(
                "No matching capabilities between us ({}) and {} ({}), disconnecting".format(
                    self.capabilities, self.remote, remote_capabilities))
        self.logger.debug(
            "Finished P2P handshake with %s, using sub-protocol %s",
            self.remote, self.sub_proto)

    def encrypt(self, header: bytes, frame: bytes) -> bytes:
        if len(header) != HEADER_LEN:
            raise ValueError("Unexpected header length: {}".format(len(header)))

        header_ciphertext = self.aes_enc.update(header)
        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), header_ciphertext))
        header_mac = self.egress_mac.digest()[:HEADER_LEN]

        frame_ciphertext = self.aes_enc.update(frame)
        self.egress_mac.update(frame_ciphertext)
        fmac_seed = self.egress_mac.digest()[:HEADER_LEN]

        mac_secret = self.egress_mac.digest()[:HEADER_LEN]
        self.egress_mac.update(sxor(self.mac_enc(mac_secret), fmac_seed))
        frame_mac = self.egress_mac.digest()[:HEADER_LEN]

        return header_ciphertext + header_mac + frame_ciphertext + frame_mac

    def decrypt_header(self, data: bytes) -> bytes:
        if len(data) != HEADER_LEN + MAC_LEN:
            raise ValueError("Unexpected header length: {}".format(len(data)))

        header_ciphertext = data[:HEADER_LEN]
        header_mac = data[HEADER_LEN:]
        mac_secret = self.ingress_mac.digest()[:HEADER_LEN]
        aes = self.mac_enc(mac_secret)[:HEADER_LEN]
        self.ingress_mac.update(sxor(aes, header_ciphertext))
        expected_header_mac = self.ingress_mac.digest()[:HEADER_LEN]
        if not bytes_eq(expected_header_mac, header_mac):
            raise AuthenticationError('Invalid header mac')
        return self.aes_dec.update(header_ciphertext)

    def decrypt_body(self, data: bytes, body_size: int) -> bytes:
        read_size = roundup_16(body_size)
        if len(data) < read_size + MAC_LEN:
            raise ValueError('Insufficient body length; Got {}, wanted {}'.format(
                len(data), (read_size + MAC_LEN)))

        frame_ciphertext = data[:read_size]
        frame_mac = data[read_size:read_size + MAC_LEN]

        self.ingress_mac.update(frame_ciphertext)
        fmac_seed = self.ingress_mac.digest()[:MAC_LEN]
        self.ingress_mac.update(sxor(self.mac_enc(fmac_seed), fmac_seed))
        expected_frame_mac = self.ingress_mac.digest()[:MAC_LEN]
        if not bytes_eq(expected_frame_mac, frame_mac):
            raise AuthenticationError('Invalid frame mac')
        return self.aes_dec.update(frame_ciphertext)[:body_size]

    def get_frame_size(self, header: bytes) -> int:
        # The frame size is encoded in the header as a 3-byte int, so before we unpack we need
        # to prefix it with an extra byte.
        encoded_size = b'\x00' + header[:3]
        (size,) = struct.unpack(b'>I', encoded_size)
        return size

    def send(self, header: bytes, body: bytes) -> None:
        cmd_id = rlp.decode(body[:1], sedes=sedes.big_endian_int)
        self.logger.debug("Sending msg with cmd_id: %s", cmd_id)
        self.writer.write(self.encrypt(header, body))

    def disconnect(self, reason: DisconnectReason) -> None:
        """Send a disconnect msg to the remote node and stop this Peer.

        :param reason: An item from the DisconnectReason enum.
        """
        if not isinstance(reason, DisconnectReason):
            raise ValueError(
                "Reason must be an item of DisconnectReason, got {}".format(reason))
        self.logger.debug("Disconnecting from remote peer; reason: %s", reason.value)
        self.base_protocol.send_disconnect(reason.value)
        self.close()

    def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
                            ) -> protocol.Protocol:
        """Select the sub-protocol to use when talking to the remote.

        Return an instance (with the appropriate cmd_id offset) of the highest version of our
        supported sub-protocols that is also supported by the remote, or None if there is no
        match.
        """
        matching_capabilities = set(self.capabilities).intersection(remote_capabilities)
        if not matching_capabilities:
            # Guard against an empty set: max() would raise ValueError here, whereas returning
            # None lets process_p2p_handshake() disconnect with DisconnectReason.useless_peer.
            return None
        _, highest_matching_version = max(matching_capabilities, key=operator.itemgetter(1))
        offset = self.base_protocol.cmd_length
        for proto_class in self._supported_sub_protocols:
            if proto_class.version == highest_matching_version:
                return proto_class(self, offset)
        return None

    def __str__(self):
        return "{} {}".format(self.__class__.__name__, self.remote)
Example No. 23
class DiscoveryProtocol(asyncio.DatagramProtocol):
    """A Kademlia-like protocol to discover RLPx nodes."""
    logger = logging.getLogger("p2p.discovery.DiscoveryProtocol")
    transport: asyncio.DatagramTransport = None
    _max_neighbours_per_packet_cache = None

    def __init__(self, privkey: datatypes.PrivateKey, address: kademlia.Address,
                 bootstrap_nodes: List[str]) -> None:
        self.privkey = privkey
        self.address = address
        self.bootstrap_nodes = [kademlia.Node.from_uri(node) for node in bootstrap_nodes]
        self.this_node = kademlia.Node(self.pubkey, address)
        self.kademlia = kademlia.KademliaProtocol(self.this_node, wire=self)
        self.cancel_token = CancelToken('DiscoveryProtocol')

    async def lookup_random(self, cancel_token: CancelToken) -> List[kademlia.Node]:
        return await self.kademlia.lookup_random(self.cancel_token.chain(cancel_token))

    def get_random_nodes(self, count: int) -> Generator[kademlia.Node, None, None]:
        return self.kademlia.routing.get_random_nodes(count)

    @property
    def pubkey(self) -> datatypes.PublicKey:
        return self.privkey.public_key

    def _get_handler(self, cmd: Command) -> Callable[[kademlia.Node, List[Any], bytes], None]:
        if cmd == CMD_PING:
            return self.recv_ping
        elif cmd == CMD_PONG:
            return self.recv_pong
        elif cmd == CMD_FIND_NODE:
            return self.recv_find_node
        elif cmd == CMD_NEIGHBOURS:
            return self.recv_neighbours
        else:
            raise ValueError("Unknwon command: {}".format(cmd))

    def _get_max_neighbours_per_packet(self):
        if self._max_neighbours_per_packet_cache is not None:
            return self._max_neighbours_per_packet_cache
        self._max_neighbours_per_packet_cache = _get_max_neighbours_per_packet()
        return self._max_neighbours_per_packet_cache

    def connection_made(self, transport):
        self.transport = transport

    async def bootstrap(self):
        while self.transport is None:
            # FIXME: Instead of sleeping here to wait until connection_made() is called to set
            # .transport we should instead only call it after we know it's been set.
            await asyncio.sleep(1)
        self.logger.debug("boostrapping with %s", self.bootstrap_nodes)
        await self.kademlia.bootstrap(self.bootstrap_nodes, self.cancel_token)

    # FIXME: Enable type checking here once we have a mypy version that
    # includes the fix for https://github.com/python/typeshed/pull/1740
    def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:  # type: ignore
        ip_address, udp_port = addr
        # XXX: For now we simply discard all v5 messages. The prefix below is what geth uses to
        # identify them:
        # https://github.com/ethereum/go-ethereum/blob/c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5/p2p/discv5/udp.go#L149  # noqa: E501
        if text_if_str(to_bytes, data).startswith(b"temporary discovery v5"):
            self.logger.debug("Got discovery v5 msg, discarding")
            return

        self.receive(kademlia.Address(ip_address, udp_port), data)  # type: ignore

    def error_received(self, exc: Exception) -> None:
        self.logger.error('error received: %s', exc)

    def send(self, node: kademlia.Node, message: bytes) -> None:
        self.transport.sendto(message, (node.address.ip, node.address.udp_port))

    async def stop(self):
        self.logger.info('stopping discovery')
        self.cancel_token.trigger()
        self.transport.close()
        # We run lots of asyncio tasks so this is to make sure they all get a chance to execute
        # and exit cleanly when they notice the cancel token has been triggered.
        await asyncio.sleep(1)

    def receive(self, address: kademlia.Address, message: bytes) -> None:
        try:
            remote_pubkey, cmd_id, payload, message_hash = _unpack(message)
        except DefectiveMessage as e:
            self.logger.error('error unpacking message (%s) from %s: %s', message, address, e)
            return

        # As of discovery version 4, expiration is the last element for all packets, so
        # we can validate that here, but if it changes we may have to do so on the
        # handler methods.
        expiration = rlp.sedes.big_endian_int.deserialize(payload[-1])
        if time.time() > expiration:
            self.logger.error('received message already expired')
            return

        cmd = CMD_ID_MAP[cmd_id]
        if len(payload) != cmd.elem_count:
            self.logger.error('invalid %s payload: %s', cmd.name, payload)
            return
        node = kademlia.Node(remote_pubkey, address)
        handler = self._get_handler(cmd)
        handler(node, payload, message_hash)

    def recv_pong(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The pong payload should have 3 elements: to, token, expiration
        _, token, _ = payload
        self.kademlia.recv_pong(node, token)

    def recv_neighbours(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The neighbours payload should have 2 elements: nodes, expiration
        nodes, _ = payload
        self.kademlia.recv_neighbours(node, _extract_nodes_from_payload(nodes))

    def recv_ping(self, node: kademlia.Node, _: Any, message_hash: bytes) -> None:
        self.kademlia.recv_ping(node, message_hash)

    def recv_find_node(self, node: kademlia.Node, payload: List[Any], _: bytes) -> None:
        # The find_node payload should have 2 elements: node_id, expiration
        self.logger.debug('<<< find_node from %s', node)
        node_id, _ = payload
        self.kademlia.recv_find_node(node, big_endian_to_int(node_id))

    def send_ping(self, node: kademlia.Node) -> bytes:
        version = rlp.sedes.big_endian_int.serialize(PROTO_VERSION)
        payload = [version, self.address.to_endpoint(), node.address.to_endpoint()]
        message = _pack(CMD_PING.id, payload, self.privkey)
        self.send(node, message)
        # Return the msg hash, which is used as a token to identify pongs.
        token = message[:MAC_SIZE]
        self.logger.debug('>>> ping %s (token == %s)', node, encode_hex(token))
        return token

    def send_find_node(self, node: kademlia.Node, target_node_id: int) -> None:
        node_id = int_to_big_endian(
            target_node_id).rjust(kademlia.k_pubkey_size // 8, b'\0')
        self.logger.debug('>>> find_node to %s', node)
        message = _pack(CMD_FIND_NODE.id, [node_id], self.privkey)
        self.send(node, message)

    def send_pong(self, node: kademlia.Node, token: bytes) -> None:
        self.logger.debug('>>> pong %s', node)
        payload = [node.address.to_endpoint(), token]
        message = _pack(CMD_PONG.id, payload, self.privkey)
        self.send(node, message)

    def send_neighbours(self, node: kademlia.Node, neighbours: List[kademlia.Node]) -> None:
        nodes = []
        neighbours = sorted(neighbours)
        for n in neighbours:
            nodes.append(n.address.to_endpoint() + [n.pubkey.to_bytes()])

        max_neighbours = self._get_max_neighbours_per_packet()
        for i in range(0, len(nodes), max_neighbours):
            message = _pack(CMD_NEIGHBOURS.id, [nodes[i:i + max_neighbours]], self.privkey)
            self.logger.debug('>>> neighbours to %s: %s',
                              node, neighbours[i:i + max_neighbours])
            self.send(node, message)
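
The class above implements asyncio's datagram-protocol callbacks (connection_made(), datagram_received(), error_received()), so it is presumably driven by loop.create_datagram_endpoint(). A minimal usage sketch, assuming the class is named DiscoveryProtocol and is already constructed with the privkey, address and bootstrap_nodes its methods reference:

# Hedged usage sketch, not part of the original example.
import asyncio

async def run_discovery(proto: 'DiscoveryProtocol', port: int = 30303) -> None:
    loop = asyncio.get_event_loop()
    # create_datagram_endpoint() invokes connection_made(), which sets
    # proto.transport and unblocks the polling loop inside bootstrap().
    await loop.create_datagram_endpoint(
        lambda: proto, local_addr=('0.0.0.0', port))
    try:
        await proto.bootstrap()
    finally:
        await proto.stop()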
Ejemplo n.º 24
0
class ChainSyncer(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.chain.ChainSyncer")
    # We'll only sync if we are connected to at least min_peers_to_sync.
    min_peers_to_sync = 2
    _reply_timeout = 5

    def __init__(self, chaindb: AsyncChainDB, peer_pool: PeerPool) -> None:
        self.chaindb = chaindb
        self.peer_pool = peer_pool
        self.peer_pool.min_peers = self.min_peers_to_sync
        self.peer_pool.subscribe(self)
        self.cancel_token = CancelToken('ChainSyncer')
        self._running_peers = set()  # type: Set[ETHPeer]
        self._syncing = False
        self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
        self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        self._body_requests = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        self._receipt_requests = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        # A mapping from (transaction_root, uncles_hash) to (block_header, request time) so that
        # we can keep track of pending block bodies and retry them when necessary.
        self._pending_bodies = {}  # type: Dict[Tuple[bytes, bytes], Tuple[BlockHeader, float]]
        # A mapping from receipt_root to (block_header, request time) so that we can keep track of
        # pending block receipts and retry them when necessary.
        self._pending_receipts = {}  # type: Dict[bytes, Tuple[BlockHeader, float]]
        asyncio.ensure_future(self.body_downloader())
        asyncio.ensure_future(self.receipt_downloader())

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))
        highest_td_peer = max(
            [cast(ETHPeer, peer) for peer in self.peer_pool.peers],
            key=operator.attrgetter('head_td'))
        self._sync_requests.put_nowait(highest_td_peer)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            try:
                await self.handle_msg(peer, cmd, msg)
            except Exception as e:
                self.logger.error(
                    "Unexpected error when processing msg from %s: %s", peer,
                    repr(e))
                break

    async def run(self) -> None:
        while True:
            try:
                peer = await wait_with_token(self._sync_requests.get(),
                                             token=self.cancel_token)
            except OperationCancelled:
                break

            asyncio.ensure_future(self.sync(peer))

            # TODO: If we're in light mode and we've synced up to head - 1024, wait until there's
            # no more pending bodies/receipts, trigger cancel token to stop and raise an exception
            # to tell our caller it should perform a state sync.

    async def sync(self, peer: ETHPeer) -> None:
        if self._syncing:
            self.logger.debug(
                "Got a NewBlock or a new peer, but already syncing so doing nothing"
            )
            return
        elif len(self._running_peers) < self.min_peers_to_sync:
            self.logger.debug(
                "Connected to less peers (%d) than the minimum (%d) required to sync, "
                "doing nothing", len(self._running_peers),
                self.min_peers_to_sync)
            return

        self._syncing = True
        try:
            await self._sync(peer)
        finally:
            self._syncing = False

    async def _sync(self, peer: ETHPeer) -> None:
        head = await self.chaindb.coro_get_canonical_head()
        head_td = await self.chaindb.coro_get_score(head.hash)
        if peer.head_td <= head_td:
            self.logger.debug(
                "Head TD (%d) announced by %s not higher than ours (%d), not syncing",
                peer.head_td, peer, head_td)
            return

        # FIXME: Fetch a batch of headers, in reverse order, starting from our current head, and
        # find the common ancestor between our chain and the peer's.
        start_at = max(0, head.block_number - eth.MAX_HEADERS_FETCH)
        self.logger.debug("Asking %s for header batch starting at %d", peer,
                          start_at)
        peer.sub_proto.send_get_block_headers(start_at,
                                              eth.MAX_HEADERS_FETCH,
                                              reverse=False)
        max_consecutive_timeouts = 3
        consecutive_timeouts = 0
        while True:
            try:
                headers = await wait_with_token(self._new_headers.get(),
                                                peer.wait_until_finished(),
                                                token=self.cancel_token,
                                                timeout=3)
            except OperationCancelled:
                break
            except TimeoutError:
                self.logger.debug("Timeout waiting for header batch from %s",
                                  peer)
                consecutive_timeouts += 1
                if consecutive_timeouts > max_consecutive_timeouts:
                    self.logger.debug(
                        "Too many consecutive timeouts waiting for header batch, aborting sync "
                        "with %s", peer)
                    break
                continue

            if peer.is_finished():
                self.logger.debug("%s disconnected, stopping sync", peer)
                break

            consecutive_timeouts = 0
            if headers[-1].block_number <= start_at:
                self.logger.debug(
                    "Ignoring headers from %d to %d as they've been processed already",
                    headers[0].block_number, headers[-1].block_number)
                continue

            # TODO: Process headers for consistency.
            for header in headers:
                await self.chaindb.coro_persist_header(header)
                start_at = header.block_number

            self._body_requests.put_nowait(headers)
            self._receipt_requests.put_nowait(headers)

            self.logger.debug("Asking %s for header batch starting at %d",
                              peer, start_at)
            # TODO: Instead of requesting sequential batches from a single peer, request a header
            # skeleton and make concurrent requests, using as many peers as possible, to fill the
            # skeleton.
            peer.sub_proto.send_get_block_headers(start_at,
                                                  eth.MAX_HEADERS_FETCH,
                                                  reverse=False)

    async def _downloader(
            self, queue: 'asyncio.Queue[List[BlockHeader]]',
            should_skip: Callable[[BlockHeader], bool],
            request_func: Callable[[List[BlockHeader]], Awaitable[None]],
            pending: Dict[Any, Tuple[BlockHeader, float]]) -> None:
        while True:
            try:
                headers = await wait_with_token(queue.get(),
                                                token=self.cancel_token,
                                                timeout=self._reply_timeout)
            except TimeoutError:
                # We use a timeout above to make sure we periodically retry timed-out items
                # even when there are no new items coming through.
                pass
            except OperationCancelled:
                return
            else:
                await request_func(
                    [header for header in headers if not should_skip(header)])

            await self._retry_timedout(request_func, pending)

    async def _retry_timedout(
            self, request_func: Callable[[List[BlockHeader]], Awaitable[None]],
            pending: Dict[Any, Tuple[BlockHeader, float]]) -> None:
        now = time.time()
        timed_out = [
            header for header, req_time in pending.values()
            if now - req_time > self._reply_timeout
        ]
        if timed_out:
            self.logger.debug("Re-requesting %d timed out block parts...",
                              len(timed_out))
            await request_func(timed_out)

    async def body_downloader(self) -> None:
        await self._downloader(self._body_requests,
                               self._should_skip_body_download,
                               self.request_bodies, self._pending_bodies)

    async def receipt_downloader(self) -> None:
        await self._downloader(self._receipt_requests,
                               self._should_skip_receipts_download,
                               self.request_receipts, self._pending_receipts)

    def _should_skip_body_download(self, header: BlockHeader) -> bool:
        return (header.transaction_root == self.chaindb.empty_root_hash
                and header.uncles_hash == EMPTY_UNCLE_HASH)

    async def request_bodies(self, headers: List[BlockHeader]) -> None:
        for batch in partition_all(eth.MAX_BODIES_FETCH, headers):
            peer = await self.peer_pool.get_random_peer()
            cast(ETHPeer, peer).sub_proto.send_get_block_bodies(
                [header.hash for header in batch])
            self.logger.debug("Requesting %d block bodies to %s", len(batch),
                              peer)
            now = time.time()
            for header in batch:
                key = (header.transaction_root, header.uncles_hash)
                self._pending_bodies[key] = (header, now)

    def _should_skip_receipts_download(self, header: BlockHeader) -> bool:
        return header.receipt_root == self.chaindb.empty_root_hash

    async def request_receipts(self, headers: List[BlockHeader]) -> None:
        for batch in partition_all(eth.MAX_RECEIPTS_FETCH, headers):
            peer = await self.peer_pool.get_random_peer()
            cast(ETHPeer, peer).sub_proto.send_get_receipts(
                [header.hash for header in batch])
            self.logger.debug("Requesting %d block receipts to %s", len(batch),
                              peer)
            now = time.time()
            for header in batch:
                self._pending_receipts[header.receipt_root] = (header, now)

    async def wait_until_finished(self) -> None:
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        self.logger.info("Waiting for %d running peers to finish",
                         len(self._running_peers))
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            await asyncio.sleep(0.1)
        else:
            self.logger.info(
                "Waited too long for peers to finish, exiting anyway")

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        await self.wait_until_finished()

    async def handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                         msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd, eth.BlockHeaders):
            msg = cast(List[BlockHeader], msg)
            self.logger.debug("Got BlockHeaders from %d to %d",
                              msg[0].block_number, msg[-1].block_number)
            self._new_headers.put_nowait(msg)
        elif isinstance(cmd, eth.BlockBodies):
            msg = cast(List[eth.BlockBody], msg)
            self.logger.debug("Got %d BlockBodies", len(msg))
            for body in msg:
                tx_root, trie_dict_data = make_trie_root_and_nodes(
                    body.transactions)
                await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
                # TODO: Add transactions to canonical chain; blocked by
                # https://github.com/ethereum/py-evm/issues/337
                uncles_hash = await self.chaindb.coro_persist_uncles(
                    body.uncles)
                self._pending_bodies.pop((tx_root, uncles_hash), None)
        elif isinstance(cmd, eth.Receipts):
            msg = cast(List[List[eth.Receipt]], msg)
            self.logger.debug("Got Receipts for %d blocks", len(msg))
            for block_receipts in msg:
                receipt_root, trie_dict_data = make_trie_root_and_nodes(
                    block_receipts)
                await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
                self._pending_receipts.pop(receipt_root, None)
        elif isinstance(cmd, eth.NewBlock):
            msg = cast(Dict[str, Any], msg)
            header = msg['block'][0]
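            # NewBlock announces the block's own total difficulty, so the TD of
            # the peer's head (this block's parent) is that value minus the
            # block's difficulty.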
            actual_head = header.parent_hash
            actual_td = msg['total_difficulty'] - header.difficulty
            if actual_td > peer.head_td:
                peer.head_hash = actual_head
                peer.head_td = actual_td
                self._sync_requests.put_nowait(peer)
        elif isinstance(cmd, eth.Transactions):
            # TODO: Figure out what to do with those.
            pass
        elif isinstance(cmd, eth.NewBlockHashes):
            # TODO: Figure out what to do with those.
            pass
        else:
            # TODO: There are other msg types we'll want to handle here, but for now just log them
            # as a warning so we don't forget about it.
            self.logger.warn("Got unexpected msg: %s (%s)", cmd, msg)
Ejemplo n.º 25
0
class ChainSyncer(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.chain.ChainSyncer")
    # We'll only sync if we are connected to at least min_peers_to_sync.
    min_peers_to_sync = 2
    # TODO: Instead of a fixed timeout, we should use a variable one that gets adjusted based on
    # the round-trip times from our download requests.
    _reply_timeout = 60

    def __init__(self,
                 chaindb: AsyncChainDB,
                 peer_pool: PeerPool,
                 token: CancelToken = None) -> None:
        self.chaindb = chaindb
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.cancel_token = CancelToken('ChainSyncer')
        if token is not None:
            self.cancel_token = self.cancel_token.chain(token)
        self._running_peers = set()  # type: Set[ETHPeer]
        self._syncing = False
        self._sync_complete = asyncio.Event()
        self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
        self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        self._downloaded_receipts = asyncio.Queue()  # type: asyncio.Queue[List[bytes]]
        self._downloaded_bodies = asyncio.Queue()  # type: asyncio.Queue[List[Tuple[bytes, bytes]]]

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))
        highest_td_peer = max(
            [cast(ETHPeer, peer) for peer in self.peer_pool.peers],
            key=operator.attrgetter('head_td'))
        self._sync_requests.put_nowait(highest_td_peer)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            pending_msgs = peer.sub_proto_msg_queue.qsize()
            if pending_msgs:
                self.logger.debug(
                    "Read %s msg from %s's queue; %d msgs pending", cmd, peer,
                    pending_msgs)

            # Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main
            # loop can keep processing msgs, and that's why we use ensure_future() instead of
            # waiting for it to finish here.
            asyncio.ensure_future(self.handle_msg(peer, cmd, msg))

    async def handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                         msg: protocol._DecodedMsgType) -> None:
        try:
            await self._handle_msg(peer, cmd, msg)
        except OperationCancelled:
            # Silently swallow OperationCancelled exceptions because we run unsupervised (i.e.
            # with ensure_future()). Our caller will also get an OperationCancelled anyway, and
            # it will be handled there.
            pass
        except Exception:
            self.logger.exception(
                "Unexpected error when processing msg from %s", peer)

    async def run(self) -> None:
        while True:
            peer_or_finished = await wait_with_token(
                self._sync_requests.get(),
                self._sync_complete.wait(),
                token=self.cancel_token)

            if self._sync_complete.is_set():
                return

            # Since self._sync_complete is not set, peer_or_finished can only be an ETHPeer
            # instance.
            asyncio.ensure_future(self.sync(peer_or_finished))

    async def sync(self, peer: ETHPeer) -> None:
        if self._syncing:
            self.logger.debug(
                "Got a NewBlock or a new peer, but already syncing so doing nothing"
            )
            return
        elif len(self._running_peers) < self.min_peers_to_sync:
            self.logger.warning(
                "Connected to fewer peers (%d) than the minimum (%d) required to sync, "
                "doing nothing", len(self._running_peers),
                self.min_peers_to_sync)
            return

        self._syncing = True
        try:
            await self._sync(peer)
        except OperationCancelled:
            pass
        finally:
            self._syncing = False

    async def _sync(self, peer: ETHPeer) -> None:
        head = await self.chaindb.coro_get_canonical_head()
        head_td = await self.chaindb.coro_get_score(head.hash)
        if peer.head_td <= head_td:
            self.logger.info(
                "Head TD (%d) announced by %s not higher than ours (%d), not syncing",
                peer.head_td, peer, head_td)
            return

        self.logger.info("Starting sync with %s", peer)
        # FIXME: Fetch a batch of headers, in reverse order, starting from our current head, and
        # find the common ancestor between our chain and the peer's.
        start_at = max(0, head.block_number - eth.MAX_HEADERS_FETCH)
        while True:
            self.logger.info("Fetching chain segment starting at #%d",
                             start_at)
            peer.sub_proto.send_get_block_headers(start_at,
                                                  eth.MAX_HEADERS_FETCH,
                                                  reverse=False)
            try:
                headers = await wait_with_token(self._new_headers.get(),
                                                peer.wait_until_finished(),
                                                token=self.cancel_token,
                                                timeout=self._reply_timeout)
            except TimeoutError:
                self.logger.warning(
                    "Timeout waiting for header batch from %s, aborting sync",
                    peer)
                await peer.stop()
                break

            if peer.is_finished():
                self.logger.info("%s disconnected, aborting sync", peer)
                break

            self.logger.info("Got headers segment starting at #%d", start_at)

            # TODO: Process headers for consistency.

            await self._download_block_parts(
                [header for header in headers if not _is_body_empty(header)],
                self.request_bodies, self._downloaded_bodies, _body_key,
                'body')

            self.logger.info(
                "Got block bodies for chain segment starting at #%d", start_at)

            missing_receipts = [
                header for header in headers if not _is_receipts_empty(header)
            ]
            # Post-Byzantium blocks may have identical receipt roots (e.g. when they have the same
            # number of transactions and they all succeeded or all failed: ropsten blocks 2503212
            # and 2503284),
            # so we do this to avoid requesting the same receipts multiple times.
            missing_receipts = list(unique(missing_receipts,
                                           key=_receipts_key))
            await self._download_block_parts(missing_receipts,
                                             self.request_receipts,
                                             self._downloaded_receipts,
                                             _receipts_key, 'receipt')

            self.logger.info(
                "Got block receipts for chain segment starting at #%d",
                start_at)

            for header in headers:
                await self.chaindb.coro_persist_header(header)
                start_at = header.block_number + 1

            self.logger.info("Imported chain segment, new head: #%d",
                             start_at - 1)
            head = await self.chaindb.coro_get_canonical_head()
            if head.hash == peer.head_hash:
                self.logger.info("Chain sync with %s completed", peer)
                self._sync_complete.set()
                break

    async def _download_block_parts(
            self,
            headers: List[BlockHeader],
            request_func: Callable[[List[BlockHeader]], int],
            download_queue: 'Union[asyncio.Queue[List[bytes]], asyncio.Queue[List[Tuple[bytes, bytes]]]]',  # noqa: E501
            key_func: Callable[[BlockHeader], Union[bytes, Tuple[bytes, bytes]]],
            part_name: str) -> None:
        missing = headers.copy()
        # The ETH protocol doesn't guarantee that we'll get all body parts requested, so we need
        # to keep track of the number of pending replies and missing items to decide when to retry
        # them. See request_receipts() for more info.
        pending_replies = request_func(missing)
        while missing:
            if pending_replies == 0:
                pending_replies = request_func(missing)

            try:
                downloaded = await wait_with_token(download_queue.get(),
                                                   token=self.cancel_token,
                                                   timeout=self._reply_timeout)
            except TimeoutError:
                pending_replies = request_func(missing)
                continue

            pending_replies -= 1
            unexpected = set(downloaded).difference(
                [key_func(header) for header in missing])
            for item in unexpected:
                self.logger.warning("Got unexpected %s: %s", part_name, item)
            missing = [
                header for header in missing
                if key_func(header) not in downloaded
            ]

    def _request_block_parts(
            self, headers: List[BlockHeader],
            request_func: Callable[[ETHPeer, List[BlockHeader]], None]) -> int:
        length = math.ceil(len(headers) / len(self.peer_pool.peers))
        batches = list(partition_all(length, headers))
        for peer, batch in zip(self.peer_pool.peers, batches):
            request_func(cast(ETHPeer, peer), batch)
        return len(batches)
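
    # For example, with 3 connected peers and 100 headers the batch length is
    # math.ceil(100 / 3) == 34, so the headers are split into batches of 34,
    # 34 and 32, one request is sent to each peer, and 3 is returned as the
    # number of pending replies.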

    def _send_get_block_bodies(self, peer: ETHPeer,
                               headers: List[BlockHeader]) -> None:
        self.logger.debug("Requesting %d block bodies to %s", len(headers),
                          peer)
        peer.sub_proto.send_get_block_bodies(
            [header.hash for header in headers])

    def _send_get_receipts(self, peer: ETHPeer,
                           headers: List[BlockHeader]) -> None:
        self.logger.debug("Requesting %d block receipts to %s", len(headers),
                          peer)
        peer.sub_proto.send_get_receipts([header.hash for header in headers])

    def request_bodies(self, headers: List[BlockHeader]) -> int:
        """Ask our peers for bodies for the given headers.

        See request_receipts() for details of how this is done.
        """
        return self._request_block_parts(headers, self._send_get_block_bodies)

    def request_receipts(self, headers: List[BlockHeader]) -> int:
        """Ask our peers for receipts for the given headers.

        We partition the given list of headers into batches and request each batch from one of
        our connected peers. This is done because geth enforces a byte-size cap when replying to
        a GetReceipts msg, and we then need to re-request the items that didn't fit, so by
        splitting the requests across all our peers we reduce the likelihood of having to make
        multiple serialized requests for missing items (which happens quite frequently in
        practice).

        Returns the number of requests made.
        """
        return self._request_block_parts(headers, self._send_get_receipts)

    async def wait_until_finished(self) -> None:
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        self.logger.info("Waiting for %d running peers to finish",
                         len(self._running_peers))
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            await asyncio.sleep(0.1)
        else:
            self.logger.info(
                "Waited too long for peers to finish, exiting anyway")

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        await self.wait_until_finished()

    async def _handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                          msg: protocol._DecodedMsgType) -> None:
        if isinstance(cmd, eth.BlockHeaders):
            msg = cast(List[BlockHeader], msg)
            self.logger.debug("Got BlockHeaders from %d to %d",
                              msg[0].block_number, msg[-1].block_number)
            self._new_headers.put_nowait(msg)
        elif isinstance(cmd, eth.BlockBodies):
            await self._handle_block_bodies(peer, cast(List[eth.BlockBody],
                                                       msg))
        elif isinstance(cmd, eth.Receipts):
            await self._handle_block_receipts(
                peer, cast(List[List[eth.Receipt]], msg))
        elif isinstance(cmd, eth.NewBlock):
            msg = cast(Dict[str, Any], msg)
            header = msg['block'][0]
            actual_head = header.parent_hash
            actual_td = msg['total_difficulty'] - header.difficulty
            if actual_td > peer.head_td:
                peer.head_hash = actual_head
                peer.head_td = actual_td
                self._sync_requests.put_nowait(peer)
        elif isinstance(cmd, eth.Transactions):
            # TODO: Figure out what to do with those.
            pass
        elif isinstance(cmd, eth.NewBlockHashes):
            # TODO: Figure out what to do with those.
            pass
        else:
            # TODO: There are other msg types we'll want to handle here, but for now just log them
            # as a warning so we don't forget about it.
            self.logger.warn("Got unexpected msg: %s (%s)", cmd, msg)

    async def _handle_block_receipts(
            self, peer: ETHPeer, receipts: List[List[eth.Receipt]]) -> None:
        self.logger.debug("Got Receipts for %d blocks from %s", len(receipts),
                          peer)
        loop = asyncio.get_event_loop()
        iterator = map(make_trie_root_and_nodes, receipts)
        receipts_tries = await wait_with_token(
            loop.run_in_executor(None, list, iterator),
            token=self.cancel_token)
        for receipt_root, trie_dict_data in receipts_tries:
            await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
        receipt_roots = [receipt_root for receipt_root, _ in receipts_tries]
        self._downloaded_receipts.put_nowait(receipt_roots)

    async def _handle_block_bodies(self, peer: ETHPeer,
                                   bodies: List[eth.BlockBody]) -> None:
        self.logger.debug("Got Bodies for %d blocks from %s", len(bodies),
                          peer)
        loop = asyncio.get_event_loop()
        iterator = map(make_trie_root_and_nodes,
                       [body.transactions for body in bodies])
        transactions_tries = await wait_with_token(
            loop.run_in_executor(None, list, iterator),
            token=self.cancel_token)
        body_keys = []
        for (body, (tx_root, trie_dict_data)) in zip(bodies,
                                                     transactions_tries):
            await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
            uncles_hash = await self.chaindb.coro_persist_uncles(body.uncles)
            body_keys.append((tx_root, uncles_hash))
        self._downloaded_bodies.put_nowait(body_keys)
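
This example relies on four module-level helpers (_is_body_empty(), _body_key(), _is_receipts_empty() and _receipts_key()) that are not part of the excerpt. Judging by the equivalent inline checks in the neighbouring examples, plausible reconstructions would be the following (the imports and the BLANK_ROOT_HASH constant are assumptions):

from typing import Tuple

from evm.constants import BLANK_ROOT_HASH, EMPTY_UNCLE_HASH  # assumed imports
from evm.rlp.headers import BlockHeader  # assumed import

def _is_body_empty(header: BlockHeader) -> bool:
    # No transactions and no uncles means there is no body to download.
    return (header.transaction_root == BLANK_ROOT_HASH
            and header.uncles_hash == EMPTY_UNCLE_HASH)

def _body_key(header: BlockHeader) -> Tuple[bytes, bytes]:
    # Bodies are tracked by their (transaction_root, uncles_hash) pair.
    return (header.transaction_root, header.uncles_hash)

def _is_receipts_empty(header: BlockHeader) -> bool:
    return header.receipt_root == BLANK_ROOT_HASH

def _receipts_key(header: BlockHeader) -> bytes:
    # Receipts are tracked by their receipt root alone.
    return header.receipt_root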
Ejemplo n.º 26
0
class ChainSyncer(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.chain.ChainSyncer")
    # We'll only sync if we are connected to at least min_peers_to_sync.
    min_peers_to_sync = 2
    # TODO: Instead of a fixed timeout, we should use a variable one that gets adjusted based on
    # the round-trip times from our download requests.
    _reply_timeout = 60

    def __init__(self, chaindb: AsyncChainDB, peer_pool: PeerPool) -> None:
        self.chaindb = chaindb
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.cancel_token = CancelToken('ChainSyncer')
        self._running_peers = set()  # type: Set[ETHPeer]
        self._peers_with_pending_requests = set()  # type: Set[ETHPeer]
        self._syncing = False
        self._sync_requests = asyncio.Queue()  # type: asyncio.Queue[ETHPeer]
        self._new_headers = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        self._body_requests = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        self._receipt_requests = asyncio.Queue()  # type: asyncio.Queue[List[BlockHeader]]
        # A mapping from (transaction_root, uncles_hash) to (block_header, request time) so that
        # we can keep track of pending block bodies and retry them when necessary.
        self._pending_bodies = {}  # type: Dict[Tuple[bytes, bytes], Tuple[BlockHeader, float]]
        # A mapping from receipt_root to (block_header, request time) so that we can keep track of
        # pending block receipts and retry them when necessary.
        self._pending_receipts = {}  # type: Dict[bytes, Tuple[BlockHeader, float]]
        asyncio.ensure_future(self.body_downloader())
        asyncio.ensure_future(self.receipt_downloader())

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))
        highest_td_peer = max(
            [cast(ETHPeer, peer) for peer in self.peer_pool.peers],
            key=operator.attrgetter('head_td'))
        self._sync_requests.put_nowait(highest_td_peer)

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break

            pending_msgs = peer.sub_proto_msg_queue.qsize()
            if pending_msgs:
                self.logger.debug(
                    "Read %s msg from %s's queue; %d msgs pending", cmd, peer,
                    pending_msgs)

            # Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main
            # loop can keep processing msgs, and that's why we use ensure_future() instead of
            # waiting for it to finish here.
            asyncio.ensure_future(self.handle_msg(peer, cmd, msg))

    async def handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                         msg: protocol._DecodedMsgType) -> None:
        try:
            await self._handle_msg(peer, cmd, msg)
        except OperationCancelled:
            # Silently swallow OperationCancelled exceptions because we run unsupervised (i.e.
            # with ensure_future()). Our caller will also get an OperationCancelled anyway, and
            # it will be handled there.
            pass
        except Exception:
            self.logger.exception(
                "Unexpected error when processing msg from %s", peer)

    async def run(self) -> None:
        while True:
            try:
                peer = await wait_with_token(self._sync_requests.get(),
                                             token=self.cancel_token)
            except OperationCancelled:
                break

            asyncio.ensure_future(self.sync(peer))

            # TODO: If we're in light mode and we've synced up to head - 1024, wait until there's
            # no more pending bodies/receipts, trigger cancel token to stop and raise an exception
            # to tell our caller it should perform a state sync.

    async def sync(self, peer: ETHPeer) -> None:
        if self._syncing:
            self.logger.debug(
                "Got a NewBlock or a new peer, but already syncing so doing nothing"
            )
            return
        elif len(self._running_peers) < self.min_peers_to_sync:
            self.logger.warning(
                "Connected to fewer peers (%d) than the minimum (%d) required to sync, "
                "doing nothing", len(self._running_peers),
                self.min_peers_to_sync)
            return

        self._syncing = True
        try:
            await self._sync(peer)
        finally:
            self._syncing = False

    async def _sync(self, peer: ETHPeer) -> None:
        head = await self.chaindb.coro_get_canonical_head()
        head_td = await self.chaindb.coro_get_score(head.hash)
        if peer.head_td <= head_td:
            self.logger.info(
                "Head TD (%d) announced by %s not higher than ours (%d), not syncing",
                peer.head_td, peer, head_td)
            return

        # FIXME: Fetch a batch of headers, in reverse order, starting from our current head, and
        # find the common ancestor between our chain and the peer's.
        start_at = max(0, head.block_number - eth.MAX_HEADERS_FETCH)
        self.logger.debug("Asking %s for header batch starting at %d", peer,
                          start_at)
        peer.sub_proto.send_get_block_headers(start_at,
                                              eth.MAX_HEADERS_FETCH,
                                              reverse=False)
        while True:
            # TODO: Consider stalling header fetching if there are more than X blocks/receipts
            # pending, to avoid timeouts caused by us not being able to process (decode/store)
            # blocks/receipts fast enough.
            try:
                headers = await wait_with_token(self._new_headers.get(),
                                                peer.wait_until_finished(),
                                                token=self.cancel_token,
                                                timeout=self._reply_timeout)
            except OperationCancelled:
                break
            except TimeoutError:
                self.logger.warning(
                    "Timeout waiting for header batch from %s, aborting sync",
                    peer)
                await peer.stop()
                break

            if peer.is_finished():
                self.logger.info("%s disconnected, aborting sync", peer)
                break

            # TODO: Process headers for consistency.
            for header in headers:
                await self.chaindb.coro_persist_header(header)
                start_at = header.block_number + 1

            self._body_requests.put_nowait(headers)
            self._receipt_requests.put_nowait(headers)

            self.logger.debug("Asking %s for header batch starting at %d",
                              peer, start_at)
            # TODO: Instead of requesting sequential batches from a single peer, request a header
            # skeleton and make concurrent requests, using as many peers as possible, to fill the
            # skeleton.
            peer.sub_proto.send_get_block_headers(start_at,
                                                  eth.MAX_HEADERS_FETCH,
                                                  reverse=False)

    async def _downloader(self, queue: 'asyncio.Queue[List[BlockHeader]]',
                          filter_func: Callable[[List[BlockHeader]],
                                                List[BlockHeader]],
                          request_func: Callable[[List[BlockHeader]],
                                                 Awaitable[None]],
                          pending: Dict[Any, Tuple[BlockHeader, float]],
                          batch_size: int, part_name: str) -> None:
        batch = []  # type: List[BlockHeader]
        while True:
            try:
                headers = await wait_with_token(
                    queue.get(),
                    token=self.cancel_token,
                    # Use a shorter timeout here because this only causes the actual retry
                    # coroutine (self._retry_timedout) to go through the pending ones and retry
                    # any items that are pending for more than self._reply_timeout seconds.
                    timeout=self._reply_timeout / 2)
                batch.extend(headers)
            except TimeoutError:
                # We use a timeout above to make sure we periodically retry timed-out items
                # even when there are no new items coming through.
                pass
            except OperationCancelled:
                return
            else:
                # Re-apply the filter function on all items because one of the things it may do is
                # drop items that have the same receipt_root.
                batch = filter_func(batch)
                if len(batch) >= batch_size:
                    await request_func(batch[:batch_size])
                    batch = batch[batch_size:]

            await self._retry_timedout(request_func, pending, batch_size,
                                       part_name)

    async def _retry_timedout(self, request_func: Callable[[List[BlockHeader]],
                                                           Awaitable[None]],
                              pending: Dict[Any, Tuple[BlockHeader, float]],
                              batch_size: int, part_name: str) -> None:
        now = time.time()
        timed_out = [
            header for header, req_time in pending.values()
            if now - req_time > self._reply_timeout
        ]
        while timed_out:
            self.logger.warning(
                "Re-requesting %d timed out %s out of %d pending ones",
                len(timed_out), part_name, len(pending))
            await request_func(timed_out[:batch_size])
            timed_out = timed_out[batch_size:]

    async def body_downloader(self) -> None:
        await self._downloader(self._body_requests, self._skip_empty_bodies,
                               self.request_bodies, self._pending_bodies,
                               eth.MAX_BODIES_FETCH, 'bodies')

    async def receipt_downloader(self) -> None:
        await self._downloader(self._receipt_requests,
                               self._skip_empty_and_duplicated_receipts,
                               self.request_receipts, self._pending_receipts,
                               eth.MAX_RECEIPTS_FETCH, 'receipts')

    @to_list
    def _skip_empty_bodies(
            self,
            headers: List[BlockHeader]) -> Generator[BlockHeader, None, None]:
        for header in headers:
            if (header.transaction_root != self.chaindb.empty_root_hash
                    or header.uncles_hash != EMPTY_UNCLE_HASH):
                yield header

    async def request_bodies(self, headers: List[BlockHeader]) -> None:
        peer = await self.get_idle_peer()
        peer.sub_proto.send_get_block_bodies(
            [header.hash for header in headers])
        self._peers_with_pending_requests.add(peer)
        self.logger.debug("Requesting %d block bodies to %s", len(headers),
                          peer)
        now = time.time()
        for header in headers:
            key = (header.transaction_root, header.uncles_hash)
            self._pending_bodies[key] = (header, now)

    @to_list
    def _skip_empty_and_duplicated_receipts(
            self,
            headers: List[BlockHeader]) -> Generator[BlockHeader, None, None]:
        # Post-Byzantium blocks may have identical receipt roots (e.g. when they have the same
        # number of transactions and they all succeeded or all failed: ropsten blocks 2503212
        # and 2503284), so
        # we have an extra check here to avoid requesting those receipts multiple times.
        headers = list(unique(headers,
                              key=operator.attrgetter('receipt_root')))
        for header in headers:
            if (header.receipt_root != self.chaindb.empty_root_hash
                    and header.receipt_root not in self._pending_receipts):
                yield header

    async def request_receipts(self, headers: List[BlockHeader]) -> None:
        peer = await self.get_idle_peer()
        peer.sub_proto.send_get_receipts([header.hash for header in headers])
        self._peers_with_pending_requests.add(peer)
        self.logger.debug("Requesting %d block receipts to %s", len(headers),
                          peer)
        now = time.time()
        for header in headers:
            self._pending_receipts[header.receipt_root] = (header, now)

    async def get_idle_peer(self) -> ETHPeer:
        """Return a random peer which we're not already expecting a response from."""
        while True:
            idle_peers = [
                peer for peer in self.peer_pool.peers
                if peer not in self._peers_with_pending_requests
            ]
            if idle_peers:
                return cast(ETHPeer, secrets.choice(idle_peers))
            else:
                self.logger.debug("No idle peers availabe, sleeping a bit")
                await asyncio.sleep(0.2)

    async def wait_until_finished(self) -> None:
        start_at = time.time()
        # Wait at most 5 seconds for pending peers to finish.
        self.logger.info("Waiting for %d running peers to finish",
                         len(self._running_peers))
        while time.time() < start_at + 5:
            if not self._running_peers:
                break
            await asyncio.sleep(0.1)
        else:
            self.logger.info(
                "Waited too long for peers to finish, exiting anyway")

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        await self.wait_until_finished()

    async def _handle_msg(self, peer: ETHPeer, cmd: protocol.Command,
                          msg: protocol._DecodedMsgType) -> None:
        loop = asyncio.get_event_loop()
        if isinstance(cmd, eth.BlockHeaders):
            msg = cast(List[BlockHeader], msg)
            self.logger.debug("Got BlockHeaders from %d to %d",
                              msg[0].block_number, msg[-1].block_number)
            self._new_headers.put_nowait(msg)
        elif isinstance(cmd, eth.BlockBodies):
            self._peers_with_pending_requests.remove(peer)
            msg = cast(List[eth.BlockBody], msg)
            self.logger.debug("Got %d BlockBodies from %s", len(msg), peer)
            iterator = map(make_trie_root_and_nodes,
                           [body.transactions for body in msg])
            transactions_tries = await wait_with_token(
                loop.run_in_executor(None, list, iterator),
                token=self.cancel_token)
            for body, (tx_root, trie_dict_data) in zip(msg, transactions_tries):
                await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
                # TODO: Add transactions to canonical chain; blocked by
                # https://github.com/ethereum/py-evm/issues/337
                uncles_hash = await self.chaindb.coro_persist_uncles(
                    body.uncles)
                self._pending_bodies.pop((tx_root, uncles_hash), None)
        elif isinstance(cmd, eth.Receipts):
            self._peers_with_pending_requests.remove(peer)
            msg = cast(List[List[eth.Receipt]], msg)
            self.logger.debug("Got Receipts for %d blocks from %s", len(msg),
                              peer)
            iterator = map(make_trie_root_and_nodes, msg)
            receipts_tries = await wait_with_token(
                loop.run_in_executor(None, list, iterator),
                token=self.cancel_token)
            for receipt_root, trie_dict_data in receipts_tries:
                if receipt_root not in self._pending_receipts:
                    self.logger.warning(
                        "Got unexpected receipt root: %s",
                        encode_hex(receipt_root),
                    )
                    continue
                await self.chaindb.coro_persist_trie_data_dict(trie_dict_data)
                self._pending_receipts.pop(receipt_root)
        elif isinstance(cmd, eth.NewBlock):
            msg = cast(Dict[str, Any], msg)
            header = msg['block'][0]
            actual_head = header.parent_hash
            actual_td = msg['total_difficulty'] - header.difficulty
            if actual_td > peer.head_td:
                peer.head_hash = actual_head
                peer.head_td = actual_td
                self._sync_requests.put_nowait(peer)
        elif isinstance(cmd, eth.Transactions):
            # TODO: Figure out what to do with those.
            pass
        elif isinstance(cmd, eth.NewBlockHashes):
            # TODO: Figure out what to do with those.
            pass
        else:
            # TODO: There are other msg types we'll want to handle here, but for now just log them
            # as a warning so we don't forget about it.
            self.logger.warn("Got unexpected msg: %s (%s)", cmd, msg)
Ejemplo n.º 27
0
class StateDownloader(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.state.StateDownloader")
    _pending_nodes = {}  # type: Dict[Any, float]
    _total_processed_nodes = 0
    _report_interval = 10  # Number of seconds between progress reports.
    # TODO: Experiment with different timeout/max_pending values to find the combination that
    # yields the best results.
    # FIXME: Should use the # of peers times MAX_STATE_FETCH here
    _max_pending = 5 * MAX_STATE_FETCH
    _reply_timeout = 10  # seconds
    # For simplicity/readability we use 0 here to force a report on the first iteration of the
    # loop.
    _last_report_time = 0

    def __init__(self, state_db: BaseDB, root_hash: bytes, peer_pool: PeerPool) -> None:
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.root_hash = root_hash
        self.scheduler = StateSync(root_hash, state_db, self.logger)
        self._running_peers = set()  # type: Set[ETHPeer]
        self.cancel_token = CancelToken('StateDownloader')

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break
            if isinstance(cmd, eth.NodeData):
                self.logger.debug("Processing NodeData with %d entries", len(msg))
                for node in msg:
                    self._total_processed_nodes += 1
                    node_key = keccak(node)
                    try:
                        self.scheduler.process([(node_key, node)])
                    except SyncRequestAlreadyProcessed:
                        # This means we received a node more than once, which can happen when we
                        # retry after a timeout.
                        pass
                    # A node may be received more than once, so pop() with a default value.
                    self._pending_nodes.pop(node_key, None)
            else:
                # It'd be very convenient if we could ignore everything that is not a NodeData
                # when doing a StateSync, but we need to double-check because peers may consider
                # that
                # "Bad Form" and disconnect from us.
                self.logger.debug("Ignoring %s(%s) while doing a StateSync", cmd, msg)

    # FIXME: Need a better criterion to select peers here.
    async def get_random_peer(self) -> ETHPeer:
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)
        peer = random.choice(self.peer_pool.peers)
        return cast(ETHPeer, peer)

    async def stop(self):
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        while self._running_peers:
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)

    async def request_next_batch(self):
        requests = self.scheduler.next_batch(MAX_STATE_FETCH)
        if not requests:
            # Although our run() loop frequently yields control to let our msg handler process
            # received nodes (scheduling new requests), there may be cases when the pending nodes
            # take a while to arrive thus causing the scheduler to run out of new requests for a
            # while.
            self.logger.debug("Scheduler queue is empty, not requesting any nodes")
            return
        self.logger.debug("Requesting %d trie nodes", len(requests))
        await self.request_nodes([request.node_key for request in requests])

    async def request_nodes(self, node_keys: List[bytes]) -> None:
        peer = await self.get_random_peer()
        now = time.time()
        for node_key in node_keys:
            self._pending_nodes[node_key] = now
        peer.sub_proto.send_get_node_data(node_keys)

    async def retry_timedout(self):
        timed_out = []
        now = time.time()
        for node_key, req_time in list(self._pending_nodes.items()):
            if now - req_time > self._reply_timeout:
                timed_out.append(node_key)
        if not timed_out:
            return
        self.logger.debug("Re-requesting %d trie nodes", len(timed_out))
        await self.request_nodes(timed_out)

    async def run(self):
        self.logger.info("Starting state sync for root hash %s", encode_hex(self.root_hash))
        while self.scheduler.has_pending_requests and not self.cancel_token.triggered:
            # Request new nodes if we haven't reached the limit of pending nodes.
            if len(self._pending_nodes) < self._max_pending:
                await self.request_next_batch()

            # Retry pending nodes that timed out.
            if self._pending_nodes:
                await self.retry_timedout()

            if len(self._pending_nodes) > self._max_pending:
                # Slow down if we've reached the limit of pending nodes.
                self.logger.debug("Pending trie nodes limit reached, sleeping a bit")
                await asyncio.sleep(0.3)
            else:
                # Yield control to ensure the Peer's msg_handler callback is called to process any
                # nodes we may have received already. Otherwise we spin too fast and don't process
                # received nodes often enough.
                await asyncio.sleep(0)

            self._maybe_report_progress()

        self.logger.info("Finished state sync with root hash %s", encode_hex(self.root_hash))

    def _maybe_report_progress(self):
        if (time.time() - self._last_report_time) >= self._report_interval:
            self._last_report_time = time.time()
            self.logger.info("Nodes processed: %d", self._total_processed_nodes)
            self.logger.info(
                "Nodes requested but not received yet: %d", len(self._pending_nodes))
            self.logger.info(
                "Nodes scheduled but not requested yet: %d", len(self.scheduler.requests))
Example #28
class StateDownloader(PeerPoolSubscriber):
    logger = logging.getLogger("p2p.state.StateDownloader")
    _total_processed_nodes = 0
    _report_interval = 10  # Number of seconds between progress reports.
    # TODO: Experiment with different timeout/max_pending values to find the combination that
    # yields the best results.
    # FIXME: Should use the # of peers times MAX_STATE_FETCH here
    _max_pending = 5 * MAX_STATE_FETCH
    _reply_timeout = 10  # seconds
    # For simplicity/readability we use 0 here to force a report on the first iteration of the
    # loop.
    _last_report_time = 0

    def __init__(self, state_db: BaseDB, root_hash: bytes, peer_pool: PeerPool) -> None:
        self.peer_pool = peer_pool
        self.peer_pool.subscribe(self)
        self.root_hash = root_hash
        self.scheduler = StateSync(root_hash, state_db, self.logger)
        self._running_peers = set()  # type: Set[ETHPeer]
        # Initialized per-instance: a class-level dict would be shared across instances.
        self._pending_nodes = {}  # type: Dict[bytes, float]
        self.cancel_token = CancelToken('StateDownloader')

    def register_peer(self, peer: BasePeer) -> None:
        asyncio.ensure_future(self.handle_peer(cast(ETHPeer, peer)))

    async def handle_peer(self, peer: ETHPeer) -> None:
        """Handle the lifecycle of the given peer."""
        self._running_peers.add(peer)
        try:
            await self._handle_peer(peer)
        finally:
            self._running_peers.remove(peer)

    async def _handle_peer(self, peer: ETHPeer) -> None:
        while True:
            try:
                cmd, msg = await peer.read_sub_proto_msg(self.cancel_token)
            except OperationCancelled:
                # Either our cancel token or the peer's has been triggered, so break out of the
                # loop.
                break
            if isinstance(cmd, eth.NodeData):
                self.logger.debug("Processing NodeData with %d entries", len(msg))
                for node in msg:
                    self._total_processed_nodes += 1
                    node_key = keccak(node)
                    try:
                        self.scheduler.process([(node_key, node)])
                    except SyncRequestAlreadyProcessed:
                        # This means we received a node more than once, which can happen when we
                        # retry after a timeout.
                        pass
                    # A node may be received more than once, so pop() with a default value.
                    self._pending_nodes.pop(node_key, None)
            else:
                # It'd be very convenient if we could ignore everything that is not a NodeData
                # when doing a StateSync, but we need to double-check because peers may
                # consider that "Bad Form" and disconnect from us.
                self.logger.debug("Ignoring %s(%s) while doing a StateSync", cmd, msg)

    # FIXME: Need a better criterion for selecting peers here.
    async def get_random_peer(self) -> ETHPeer:
        while not self.peer_pool.peers:
            self.logger.debug("No connected peers, sleeping a bit")
            await asyncio.sleep(0.5)
        peer = random.choice(self.peer_pool.peers)
        return cast(ETHPeer, peer)

    async def stop(self) -> None:
        self.cancel_token.trigger()
        self.peer_pool.unsubscribe(self)
        while self._running_peers:
            self.logger.debug("Waiting for %d running peers to finish", len(self._running_peers))
            await asyncio.sleep(0.1)

    async def request_next_batch(self) -> None:
        requests = self.scheduler.next_batch(MAX_STATE_FETCH)
        if not requests:
            # Although our run() loop frequently yields control to let our msg handler process
            # received nodes (scheduling new requests), there may be cases when the pending nodes
            # take a while to arrive thus causing the scheduler to run out of new requests for a
            # while.
            self.logger.debug("Scheduler queue is empty, not requesting any nodes")
            return
        self.logger.debug("Requesting %d trie nodes", len(requests))
        await self.request_nodes([request.node_key for request in requests])

    async def request_nodes(self, node_keys: List[bytes]) -> None:
        peer = await self.get_random_peer()
        now = time.time()
        for node_key in node_keys:
            self._pending_nodes[node_key] = now
        peer.sub_proto.send_get_node_data(node_keys)

    async def retry_timedout(self) -> None:
        timed_out = []
        now = time.time()
        for node_key, req_time in list(self._pending_nodes.items()):
            if now - req_time > self._reply_timeout:
                timed_out.append(node_key)
        if not timed_out:
            return
        self.logger.debug("Re-requesting %d trie nodes", len(timed_out))
        await self.request_nodes(timed_out)

    async def run(self) -> None:
        self.logger.info("Starting state sync for root hash %s", encode_hex(self.root_hash))
        while self.scheduler.has_pending_requests and not self.cancel_token.triggered:
            # Request new nodes if we haven't reached the limit of pending nodes.
            if len(self._pending_nodes) < self._max_pending:
                await self.request_next_batch()

            # Retry pending nodes that timed out.
            if self._pending_nodes:
                await self.retry_timedout()

            if len(self._pending_nodes) >= self._max_pending:
                # Slow down if we've reached the limit of pending nodes.
                self.logger.debug("Pending trie nodes limit reached, sleeping a bit")
                await asyncio.sleep(0.3)
            else:
                # Yield control to ensure the Peer's msg_handler callback is called to process any
                # nodes we may have received already. Otherwise we spin too fast and don't process
                # received nodes often enough.
                await asyncio.sleep(0)

            self._maybe_report_progress()

        self.logger.info("Finished state sync with root hash %s", encode_hex(self.root_hash))

    def _maybe_report_progress(self) -> None:
        if (time.time() - self._last_report_time) >= self._report_interval:
            self._last_report_time = time.time()
            self.logger.info("Nodes processed: %d", self._total_processed_nodes)
            self.logger.info(
                "Nodes requested but not received yet: %d", len(self._pending_nodes))
            self.logger.info(
                "Nodes scheduled but not requested yet: %d", len(self.scheduler.requests))
Example #29
def test_token_single():
    token = CancelToken('token')
    assert not token.triggered
    token.trigger()
    assert token.triggered
    assert token.triggered_token == token
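
The test above only exercises the token's state transitions; elsewhere in these examples the same token drives cooperative shutdown. Here is a minimal sketch of that usage, relying only on the `triggered` / `trigger()` API the tests demonstrate; the `worker` coroutine is hypothetical.

import asyncio


async def worker(token: CancelToken) -> None:
    # Loop until some other component triggers the token.
    while not token.triggered:
        await asyncio.sleep(0.1)  # stand-in for one unit of real work


async def main() -> None:
    token = CancelToken('worker')
    task = asyncio.ensure_future(worker(token))
    await asyncio.sleep(0.5)
    token.trigger()  # request a cooperative shutdown
    await task       # worker exits on its next loop iteration

# To run the sketch: asyncio.get_event_loop().run_until_complete(main())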
Example #30
class Server:
    """Server listening for incoming connections"""
    logger = logging.getLogger("p2p.server.Server")
    _server = None

    def __init__(self, privkey: datatypes.PrivateKey, server_address: Address,
                 chaindb: AsyncChainDB, bootstrap_nodes: List[str],
                 network_id: int) -> None:
        self.cancel_token = CancelToken('Server')
        self.chaindb = chaindb
        self.privkey = privkey
        self.server_address = server_address
        self.network_id = network_id
        # TODO: bootstrap_nodes should be looked up by network_id.
        discovery = DiscoveryProtocol(self.privkey,
                                      self.server_address,
                                      bootstrap_nodes=bootstrap_nodes)
        self.peer_pool = PeerPool(ETHPeer, self.chaindb, self.network_id,
                                  self.privkey, discovery)

    async def start(self) -> None:
        self._server = await asyncio.start_server(
            self.receive_handshake,
            host=self.server_address.ip,
            port=self.server_address.tcp_port,
        )

    async def run(self) -> None:
        await self.start()
        self.logger.info("Running server...")
        await self.cancel_token.wait()
        await self.stop()

    async def stop(self) -> None:
        self.logger.info("Closing server...")
        self.cancel_token.trigger()
        if self._server is not None:
            self._server.close()
            await self._server.wait_closed()
        await self.peer_pool.stop()

    async def receive_handshake(self, reader: asyncio.StreamReader,
                                writer: asyncio.StreamWriter) -> None:
        # Read the auth_init message; pre-EIP-8 messages have a fixed length.
        msg = await reader.read(ENCRYPTED_AUTH_MSG_LEN)

        # Decode the auth_init message, falling back to the EIP-8 format if the
        # fixed-size decoding fails.
        try:
            ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                msg, self.privkey)
        except DecryptionError:
            # EIP-8 messages prefix their total size in the first two bytes, so
            # read whatever is still missing and retry.
            msg_size = big_endian_to_int(msg[:2])
            remaining_bytes = msg_size - ENCRYPTED_AUTH_MSG_LEN + 2
            msg += await reader.read(remaining_bytes)
            ephem_pubkey, initiator_nonce, initiator_pubkey = decode_authentication(
                msg, self.privkey)

        # Get remote's address: IPv4 or IPv6
        ip, port, *_ = writer.get_extra_info("peername")
        remote_address = Address(ip, port)

        # Create a `HandshakeResponder(remote: kademlia.Node, privkey: datatypes.PrivateKey)` instance
        initiator_remote = Node(initiator_pubkey, remote_address)
        responder = HandshakeResponder(initiator_remote, self.privkey)

        # Call `HandshakeResponder.create_auth_ack_message(nonce: bytes)` to create the reply
        responder_nonce = secrets.token_bytes(HASH_LEN)
        auth_ack_msg = responder.create_auth_ack_message(nonce=responder_nonce)
        auth_ack_ciphertext = responder.encrypt_auth_ack_message(auth_ack_msg)

        # Use the `writer` to send the reply to the remote
        writer.write(auth_ack_ciphertext)
        await writer.drain()

        # Call `HandshakeResponder.derive_secrets()` and use the return values to create the peer
        aes_secret, mac_secret, egress_mac, ingress_mac = responder.derive_secrets(
            initiator_nonce=initiator_nonce,
            responder_nonce=responder_nonce,
            remote_ephemeral_pubkey=ephem_pubkey,
            auth_init_ciphertext=msg,
            auth_ack_ciphertext=auth_ack_ciphertext)

        # Create and register peer in peer_pool
        eth_peer = ETHPeer(remote=initiator_remote,
                           privkey=self.privkey,
                           reader=reader,
                           writer=writer,
                           aes_secret=aes_secret,
                           mac_secret=mac_secret,
                           egress_mac=egress_mac,
                           ingress_mac=ingress_mac,
                           chaindb=self.chaindb,
                           network_id=self.network_id)
        self.peer_pool.add_peer(eth_peer)
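
A hypothetical wiring sketch for the Server above; `privkey`, `listen_address`, and `chaindb` are placeholders that would come from real node configuration and are not defined in this example.

import asyncio


def run_server(privkey, listen_address, chaindb) -> None:
    server = Server(privkey, listen_address, chaindb,
                    bootstrap_nodes=[], network_id=1)
    loop = asyncio.get_event_loop()
    try:
        # run() blocks until the cancel token is triggered, then calls stop().
        loop.run_until_complete(server.run())
    except KeyboardInterrupt:
        # Trigger the token ourselves and wait for a clean shutdown.
        loop.run_until_complete(server.stop())
    finally:
        loop.close()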