Example #1
0
class MerkleCache:
    '''A cache to calculate merkle branches efficiently.

    Rather than every leaf hash, the cache keeps one merkle "level": one
    entry per segment of 1 << depth_higher consecutive leaf hashes (hence
    the `index >> depth_higher` indexing below).  A branch is then built
    from that cached level plus the leaf hashes of a single segment,
    fetched on demand from source_func.
    '''

    def __init__(self, merkle, source_func):
        '''Initialise a cache of hashes taken from source_func:

           async def source_func(index, count):
              ...

        merkle supplies the tree operations used here (level, tree_depth,
        branch_and_root and branch_and_root_from_level).  The cache is
        empty until initialize() has been awaited.
        '''
        self.merkle = merkle
        self.source_func = source_func
        # Number of leaf hashes currently covered by self.level.
        self.length = 0
        # Cached level: one entry per segment of _segment_length() leaves.
        self.level = []
        # Height of the cached level above the leaves; set by initialize().
        self.depth_higher = 0
        # Set once initialize() completes; branch_and_root() waits on it.
        self.initialized = Event()

    def _segment_length(self):
        '''Return the number of leaf hashes per cached-level entry.'''
        return 1 << self.depth_higher

    def _leaf_start(self, index):
        '''Return the index of the first leaf hash of the segment that
        contains leaf hash `index`, i.e. index rounded down to a multiple
        of _segment_length().
        '''
        depth_higher = self.depth_higher
        return (index >> depth_higher) << depth_higher

    def _level(self, hashes):
        '''Return the cached-level entries for hashes, computed by
        merkle.level at the current depth_higher.'''
        return self.merkle.level(hashes, self.depth_higher)

    async def _extend_to(self, length):
        '''Extend the cache to cover length leaf hashes if it covers fewer.'''
        if length <= self.length:
            return
        # Start from the beginning of any final partial segment: its cached
        # entry was computed from fewer leaves and must be recomputed.
        # Retain the value of depth_higher; in practice this is fine
        start = self._leaf_start(self.length)
        hashes = await self.source_func(start, length - start)
        self.level[start >> self.depth_higher:] = self._level(hashes)
        self.length = length

    async def _level_for(self, length):
        '''Return the level as it would be for the hashes truncated to the
        given length.

        Callers ensure length <= self.length (branch_and_root extends the
        cache first).  Whole segments come from the cache; the final
        partial segment, if any, is recomputed from source_func.  When
        length equals self.length the cached list itself is returned, not
        a copy.
        '''
        if length == self.length:
            return self.level
        level = self.level[:length >> self.depth_higher]
        leaf_start = self._leaf_start(length)
        count = min(self._segment_length(), length - leaf_start)
        hashes = await self.source_func(leaf_start, count)
        level += self._level(hashes)
        return level

    async def initialize(self, length):
        '''Call to initialize the cache to a source of given length.

        The cache's state is only replaced after source_func has returned
        and the level has been computed, so a failure leaves the cache
        exactly as it was (and `initialized` unset).
        '''
        depth_higher = self.merkle.tree_depth(length) // 2
        hashes = await self.source_func(0, length)
        level = self.merkle.level(hashes, depth_higher)
        # Commit all state together only once everything has succeeded.
        self.depth_higher = depth_higher
        self.level = level
        self.length = length
        self.initialized.set()

    def truncate(self, length):
        '''Truncate the cache so it covers no more than length underlying
        hashes.

        The cache is cut back to the start of the segment containing
        length, so it may end up covering fewer than length hashes.

        Raises TypeError if length is not an int and ValueError if it is
        not positive.
        '''
        if not isinstance(length, int):
            raise TypeError('length must be an integer')
        if length <= 0:
            raise ValueError('length must be positive')
        if length >= self.length:
            return
        length = self._leaf_start(length)
        self.length = length
        self.level[length >> self.depth_higher:] = []

    async def branch_and_root(self, length, index):
        '''Return a merkle branch and root.  Length is the number of
        hashes used to calculate the merkle root, index is the position
        of the hash to calculate the branch of.

        index must be less than length, which must be at least 1.
        Raises TypeError for non-int arguments and ValueError for
        out-of-range ones.  Waits until initialize() has completed.'''
        if not isinstance(length, int):
            raise TypeError('length must be an integer')
        if not isinstance(index, int):
            raise TypeError('index must be an integer')
        if length <= 0:
            raise ValueError('length must be positive')
        if index >= length:
            raise ValueError('index must be less than length')
        await self.initialized.wait()
        await self._extend_to(length)
        leaf_start = self._leaf_start(index)
        count = min(self._segment_length(), length - leaf_start)
        leaf_hashes = await self.source_func(leaf_start, count)
        if length < self._segment_length():
            # All hashes fit in the first segment, so leaf_hashes is the
            # complete list and no cached level is needed.
            return self.merkle.branch_and_root(leaf_hashes, index)
        level = await self._level_for(length)
        return self.merkle.branch_and_root_from_level(level, leaf_hashes,
                                                      index, self.depth_higher)
class SessionManager(object):
    '''Holds global state about all sessions.'''

    def __init__(self, env, db, bp, daemon, mempool, shutdown_event):
        '''Create the manager; env, db, bp, daemon and mempool are the
        server's shared components and shutdown_event is set to stop it.
        '''
        # Enforce a floor of 350,000 bytes on max_send.  Note this mutates
        # the shared env object, before PeerManager is built from it.
        env.max_send = max(350000, env.max_send)
        self.env, self.db, self.bp = env, db, bp
        self.daemon, self.mempool = daemon, mempool
        self.peer_mgr = PeerManager(env, db)
        self.shutdown_event = shutdown_event
        self.logger = util.class_logger(__name__, self.__class__.__name__)
        # Listening servers keyed by kind ('RPC', 'TCP', 'SSL').
        self.servers = {}
        self.sessions = set()
        self.max_subs = env.max_subs
        self.cur_group = SessionGroup(0)
        self.txs_sent = 0
        self.start_time = time.time()
        # hashX -> limited history; filled by limited_history().
        self.history_cache = pylru.lrucache(256)
        self.notified_height = None
        # Cache some idea of room to avoid recounting on each subscription
        self.subs_room = 0
        # Event triggered when electrumx is listening for incoming requests.
        self.server_listening = Event()
        # Set by add_session()/remove_session(); watched by _manage_servers.
        self.session_event = Event()

        # Route each LocalRPC command to the matching rpc_<command> method.
        command_names = ('add_peer', 'daemon_url', 'disconnect', 'getinfo',
                         'groups', 'log', 'peers', 'query', 'reorg',
                         'sessions', 'stop')
        handlers = {}
        for name in command_names:
            handlers[name] = getattr(self, 'rpc_' + name)
        LocalRPC.request_handlers = handlers

    async def _start_server(self, kind, *args, **kw_args):
        '''Start a listening server of the given kind and record it in
        self.servers[kind].

        kind 'RPC' serves LocalRPC; any other kind serves the coin's
        SESSIONCLS.  args/kw_args go to loop.create_server(); the first
        two positional args are the host and port.  An OSError while
        starting to listen is logged rather than raised.
        '''
        loop = asyncio.get_event_loop()
        protocol_class = (LocalRPC if kind == 'RPC'
                          else self.env.coin.SESSIONCLS)
        factory = partial(protocol_class, self, self.db,
                          self.mempool, self.peer_mgr, kind)
        pending_server = loop.create_server(factory, *args, **kw_args)

        host, port = args[:2]
        try:
            server = await pending_server
        except OSError as e:    # don't suppress CancelledError
            self.logger.error(f'{kind} server failed to listen on {host}:'
                              f'{port:d} :{e!r}')
            return
        self.servers[kind] = server
        self.logger.info(f'{kind} server listening on {host}:{port:d}')

    async def _start_external_servers(self):
        '''Start listening on TCP and SSL ports, but only if the respective
        port was given in the environment.

        server_listening is set afterwards in all cases; _start_server logs
        (rather than raises) an OSError when a server fails to listen.
        '''
        env = self.env
        host = env.cs_host(for_rpc=False)
        if env.tcp_port is not None:
            await self._start_server('TCP', host, env.tcp_port)
        if env.ssl_port is not None:
            # Server-side TLS context.  The generic ssl.PROTOCOL_TLS is
            # deprecated since Python 3.10 in favour of the explicit
            # client/server constants.
            sslc = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile)
            await self._start_server('SSL', host, env.ssl_port, ssl=sslc)
        self.server_listening.set()

    async def _close_servers(self, kinds):
        '''Close the listening servers of the given kinds (e.g. 'TCP').

        Each server found is removed from self.servers, closed and awaited
        until fully closed; kinds with no server are skipped.
        '''
        if not kinds:
            return
        self.logger.info('closing down {} listening servers'
                         .format(', '.join(kinds)))
        for kind in kinds:
            server = self.servers.pop(kind, None)
            if not server:
                continue
            server.close()
            await server.wait_closed()

    async def _manage_servers(self):
        '''Pause and resume the external listeners based on session count.

        Wakes on every session_event.  When the count reaches max_sessions
        the TCP and SSL servers are closed; they are restarted once it
        falls to max_sessions * 19 // 20 (the low watermark).
        '''
        max_sessions = self.env.max_sessions
        low_watermark = max_sessions * 19 // 20
        accepting = True
        while True:
            await self.session_event.wait()
            self.session_event.clear()
            if accepting and len(self.sessions) >= max_sessions:
                self.logger.info(f'maximum sessions {max_sessions:,d} '
                                 f'reached, stopping new connections until '
                                 f'count drops to {low_watermark:,d}')
                await self._close_servers(['TCP', 'SSL'])
                accepting = False
            # Re-read the count: it may have changed while closing.  The
            # gap between the two thresholds avoids flapping.
            if not accepting and len(self.sessions) <= low_watermark:
                self.logger.info('resuming listening for incoming connections')
                await self._start_external_servers()
                accepting = True

    async def _log_sessions(self):
        '''Every env.log_sessions seconds, log one line per session and a
        JSON server summary.  Returns at once if the interval is falsy.'''
        log_interval = self.env.log_sessions
        if not log_interval:
            return
        while True:
            await sleep(log_interval)
            session_rows = self._session_data(for_log=True)
            for line in text.sessions_lines(session_rows):
                self.logger.info(line)
            summary = self._get_info()
            self.logger.info(json.dumps(summary))

    def _group_map(self):
        '''Return a defaultdict(list) mapping each session group to the
        sessions currently in it.'''
        by_group = defaultdict(list)
        for sess in self.sessions:
            by_group[sess.group].append(sess)
        return by_group

    def _sub_count(self):
        '''Return the total number of subscriptions over all sessions.'''
        total = 0
        for session in self.sessions:
            total += session.sub_count()
        return total

    def _lookup_session(self, session_id):
        '''Return the session whose session_id equals int(session_id).

        Returns None if session_id cannot be converted to an int or no
        session matches.
        '''
        try:
            wanted = int(session_id)
        except Exception:
            return None
        return next((session for session in self.sessions
                     if session.session_id == wanted), None)

    async def _for_each_session(self, session_ids, operation):
        '''Await operation(session) for each session ID in session_ids.

        Returns a list holding, per ID in order, either the operation's
        result or an 'unknown session' message.  Raises RPCError if
        session_ids is not a list.
        '''
        if not isinstance(session_ids, list):
            raise RPCError(BAD_REQUEST, 'expected a list of session IDs')

        outcomes = []
        for session_id in session_ids:
            session = self._lookup_session(session_id)
            if not session:
                outcomes.append(f'unknown session: {session_id}')
                continue
            outcomes.append(await operation(session))
        return outcomes

    async def _clear_stale_sessions(self):
        '''Every 60 seconds, close idle sessions and merge small groups.

        A session is stale when its last_recv timestamp is more than
        env.session_timeout seconds old.  A group is small when it has at
        most 5 sessions whose combined bw_charge is below
        env.bandwidth_limit; if more than one group is small, all their
        sessions are moved into one of them.
        '''
        while True:
            await sleep(60)
            stale_cutoff = time.time() - self.env.session_timeout
            stale_sessions = [session for session in self.sessions
                              if session.last_recv < stale_cutoff]
            if stale_sessions:
                # Deliberately not named `text`: that would shadow the
                # module-level `text` name that _log_sessions relies on.
                session_ids = ', '.join(str(session.session_id)
                                        for session in stale_sessions)
                self.logger.info(f'closing stale connections {session_ids}')
                # Give the sockets some time to close gracefully
                for session in stale_sessions:
                    await session.spawn(session.close())

            # Consolidate small groups
            bw_limit = self.env.bandwidth_limit
            group_map = self._group_map()
            groups = [group for group, sessions in group_map.items()
                      if len(sessions) <= 5 and
                      sum(s.bw_charge for s in sessions) < bw_limit]
            if len(groups) > 1:
                new_group = groups[-1]
                for group in groups:
                    for session in group_map[group]:
                        session.group = new_group

    def _get_info(self):
        '''Return a dict summarising server state.

        Returned by the RPC 'getinfo' call and logged as JSON by
        _log_sessions.  Keys are inserted in a fixed order.
        '''
        sessions = self.sessions
        group_map = self._group_map()
        info = {}
        info['closing'] = sum(1 for s in sessions if s.is_closing())
        info['daemon'] = self.daemon.logged_url()
        info['daemon_height'] = self.daemon.cached_height()
        info['db_height'] = self.db.db_height
        info['errors'] = sum(s.errors for s in sessions)
        info['groups'] = len(group_map)
        info['logged'] = sum(1 for s in sessions if s.log_me)
        # Count of sessions whose _can_send event is not set.
        info['paused'] = sum(not s._can_send.is_set() for s in sessions)
        info['pid'] = os.getpid()
        info['peers'] = self.peer_mgr.info()
        info['requests'] = sum(s.count_pending_items() for s in sessions)
        info['sessions'] = self.session_count()
        info['subs'] = self._sub_count()
        info['txs_sent'] = self.txs_sent
        info['uptime'] = util.formatted_time(time.time() - self.start_time)
        info['version'] = electrumx.version
        return info

    def _session_data(self, for_log):
        '''Return per-session statistics, oldest session first.

        Each entry is a tuple of (session_id, flags, peer address, client,
        protocol version, pending requests, txs sent, subscriptions,
        recv count, recv size, send count, send size, age in seconds).
        Returned to the RPC 'sessions' call; for_log is passed through to
        peer_address_str().
        '''
        now = time.time()
        rows = []
        for s in sorted(self.sessions, key=lambda s: s.start_time):
            rows.append((
                s.session_id,
                s.flags(),
                s.peer_address_str(for_log=for_log),
                s.client,
                s.protocol_version_string(),
                s.count_pending_items(),
                s.txs_sent,
                s.sub_count(),
                s.recv_count, s.recv_size,
                s.send_count, s.send_size,
                now - s.start_time,
            ))
        return rows

    def _group_data(self):
        '''Return per-group statistics for the RPC 'groups' call.

        One list per group: [gid, session count, then the totals over its
        sessions of bw_charge, pending requests, txs sent, subscriptions,
        recv count, recv size, send count, send size].
        '''
        extractors = (
            lambda s: s.bw_charge,
            lambda s: s.count_pending_items(),
            lambda s: s.txs_sent,
            lambda s: s.sub_count(),
            lambda s: s.recv_count,
            lambda s: s.recv_size,
            lambda s: s.send_count,
            lambda s: s.send_size,
        )
        return [[group.gid, len(members)]
                + [sum(get(s) for s in members) for get in extractors]
                for group, members in self._group_map().items()]

    async def _electrum_and_raw_headers(self, height):
        '''Return an (electrum_header, raw_header) pair for height; the
        electrum form comes from coin.electrum_header().'''
        raw = await self.raw_header(height)
        return self.env.coin.electrum_header(raw, height), raw

    async def _refresh_hsub_results(self, height):
        '''Rebuild the cached header-subscription results for height and
        record that height as notified_height.

        height is first capped at db.db_height.  hsub_results becomes
        (electrum_header, {'hex': raw header hex, 'height': height}).
        '''
        # Paranoia: a reorg could race and leave db_height lower
        capped = min(height, self.db.db_height)
        electrum_hdr, raw_hdr = await self._electrum_and_raw_headers(capped)
        raw_result = {'hex': raw_hdr.hex(), 'height': capped}
        self.hsub_results = (electrum_hdr, raw_result)
        self.notified_height = capped

    # --- LocalRPC command handlers

    async def rpc_add_peer(self, real_name):
        '''Add a peer via the peer manager.

        real_name: "btc.electrumx.org t50001 s50002" for example
        '''
        peer_mgr = self.peer_mgr
        await peer_mgr.add_localRPC_peer(real_name)
        return f"peer '{real_name}' added"

    async def rpc_disconnect(self, session_ids):
        '''Disconnect sessions.

        session_ids: array of session IDs
        Returns one status string per ID.
        '''
        async def disconnect_one(session):
            # Ask the session to close, forcing it after 2 seconds.
            await session.close(force_after=2)
            return f'disconnected {session.session_id}'

        results = await self._for_each_session(session_ids, disconnect_one)
        return results

    async def rpc_log(self, session_ids):
        '''Toggle logging of sessions.

        session_ids: array of session IDs
        Returns one status string per ID giving the new log_me value.
        '''
        async def flip(session):
            session.toggle_logging()
            return f'log {session.session_id}: {session.log_me}'

        results = await self._for_each_session(session_ids, flip)
        return results

    async def rpc_daemon_url(self, daemon_url):
        '''Replace the daemon URL.

        daemon_url: the new URL; if empty, env.daemon_url is used instead.
        Returns a message naming the daemon's (logged) URL.  Raises
        RPCError(BAD_REQUEST) if daemon.set_url() raises.
        '''
        daemon_url = daemon_url or self.env.daemon_url
        try:
            self.daemon.set_url(daemon_url)
        except Exception as e:
            # The message already carries repr(e); suppress chaining as
            # daemon_request() and raw_header() do for converted errors.
            raise RPCError(BAD_REQUEST,
                           f'an error occurred: {e!r}') from None
        return f'now using daemon at {self.daemon.logged_url()}'

    async def rpc_stop(self):
        '''Request a clean shutdown by setting shutdown_event.'''
        stop_event = self.shutdown_event
        stop_event.set()
        return 'stopping'

    async def rpc_getinfo(self):
        '''RPC handler: summary information about the server process.'''
        summary = self._get_info()
        return summary

    async def rpc_groups(self):
        '''RPC handler: per-group session statistics.'''
        group_stats = self._group_data()
        return group_stats

    async def rpc_peers(self):
        '''RPC handler: data about server peers, from the peer manager.'''
        peers_data = self.peer_mgr.rpc_data()
        return peers_data

    async def rpc_query(self, items, limit):
        '''Return human-readable lines describing the history, UTXOs and
        balance of each queried item.

        items: strings, each a hex-encoded script or an address
        limit: passed to db.limited_history, and caps the UTXOs listed
        Items that do not parse as an address add the parse error message
        and are skipped.  The balance always covers all UTXOs, not only
        those listed.
        '''
        coin = self.env.coin
        db = self.db
        lines = []

        def arg_to_hashX(arg):
            '''Return the hashX for arg (script hex or address), or None;
            records how arg was interpreted (or the error) in lines.'''
            # Only the hex parse may fall through to address parsing; an
            # error from hashX_from_script must not be mistaken for "not
            # a script" after the Script: line has already been emitted.
            try:
                script = bytes.fromhex(arg)
            except ValueError:
                pass
            else:
                lines.append(f'Script: {arg}')
                return coin.hashX_from_script(script)

            try:
                hashX = coin.address_to_hashX(arg)
            except Base58Error as e:
                lines.append(e.args[0])
                return None
            lines.append(f'Address: {arg}')
            return hashX

        for arg in items:
            hashX = arg_to_hashX(arg)
            if not hashX:
                continue
            n = None
            history = await db.limited_history(hashX, limit=limit)
            for n, (tx_hash, height) in enumerate(history):
                lines.append(f'History #{n:,d}: height {height:,d} '
                             f'tx_hash {hash_to_hex_str(tx_hash)}')
            if n is None:
                lines.append('No history found')
            n = None
            utxos = await db.all_utxos(hashX)
            for n, utxo in enumerate(utxos, start=1):
                lines.append(f'UTXO #{n:,d}: tx_hash '
                             f'{hash_to_hex_str(utxo.tx_hash)} '
                             f'tx_pos {utxo.tx_pos:,d} height '
                             f'{utxo.height:,d} value {utxo.value:,d}')
                if n == limit:
                    break
            if n is None:
                lines.append('No UTXOs found')

            balance = sum(utxo.value for utxo in utxos)
            lines.append(f'Balance: {coin.decimal_value(balance):,f} '
                         f'{coin.SHORTNAME}')

        return lines

    async def rpc_sessions(self):
        '''RPC handler: statistics about connected sessions (peer
        addresses unredacted for RPC use, i.e. for_log=False).'''
        stats = self._session_data(for_log=False)
        return stats

    async def rpc_reorg(self, count):
        '''Force a reorg of the given number of blocks.

        count: number of blocks to reorg (validated by
        non_negative_integer)
        Raises RPCError(BAD_REQUEST) if bp.force_chain_reorg() declines.
        '''
        blocks = non_negative_integer(count)
        if self.bp.force_chain_reorg(blocks):
            return f'scheduled a reorg of {blocks:,d} blocks'
        raise RPCError(BAD_REQUEST, 'still catching up with daemon')

    # --- External Interface

    async def serve(self, notifications, event):
        '''Start the RPC server if enabled.  When the event is triggered,
        start TCP and SSL servers.

        After event fires this logs the configured limits, starts
        notifications (which initialise hsub_results) and the external
        servers, then runs the background tasks.  On exit, however it
        happens, all servers are closed, then every session is closed and
        awaited.
        '''
        try:
            env = self.env
            log = self.logger.info
            if env.rpc_port is not None:
                await self._start_server('RPC', env.cs_host(for_rpc=True),
                                         env.rpc_port)
            await event.wait()
            log(f'max session count: {env.max_sessions:,d}')
            log(f'session timeout: {env.session_timeout:,d} seconds')
            log(f'session bandwidth limit {env.bandwidth_limit:,d} bytes')
            log(f'max response size {env.max_send:,d} bytes')
            log(f'max subscriptions across all sessions: {self.max_subs:,d}')
            log(f'max subscriptions per session: {env.max_session_subs:,d}')
            if env.drop_client is not None:
                log(f'drop clients matching: {env.drop_client.pattern}')
            # Start notifications; initialize hsub_results
            await notifications.start(self.db.db_height, self._notify_sessions)
            await self._start_external_servers()
            # Peer discovery should start after the external servers
            # because we connect to ourself
            async with TaskGroup() as group:
                for task_func in (self.peer_mgr.discover_peers,
                                  self._clear_stale_sessions,
                                  self._log_sessions,
                                  self._manage_servers):
                    await group.spawn(task_func())
        finally:
            # Close servers then sessions
            await self._close_servers(list(self.servers))
            for session in list(self.sessions):
                await session.spawn(session.close(force_after=1))
            # Fresh snapshot: sessions may have gone away while closing.
            for session in list(self.sessions):
                await session.closed_event.wait()

    def session_count(self):
        '''Return the number of sessions currently registered with the
        manager.'''
        sessions = self.sessions
        return len(sessions)

    async def daemon_request(self, method, *args):
        '''Await daemon.<method>(*args) and return its result, converting
        a DaemonError into an RPCError(DAEMON_ERROR).'''
        try:
            handler = getattr(self.daemon, method)
            result = await handler(*args)
        except DaemonError as e:
            raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None
        return result

    async def raw_header(self, height):
        '''Return the binary header at the given height.

        An IndexError from the DB becomes RPCError(BAD_REQUEST).
        '''
        try:
            header = await self.db.raw_header(height)
        except IndexError:
            raise RPCError(BAD_REQUEST, f'height {height:,d} '
                           'out of range') from None
        return header

    async def electrum_header(self, height):
        '''Return the deserialized (electrum-form) header at height.'''
        both = await self._electrum_and_raw_headers(height)
        return both[0]

    async def broadcast_transaction(self, raw_tx):
        '''Send raw_tx through the daemon and return the daemon's result.

        txs_sent is only incremented if the daemon call succeeds.
        '''
        result = await self.daemon.broadcast_transaction(raw_tx)
        self.txs_sent = self.txs_sent + 1
        return result

    async def limited_history(self, hashX):
        '''Return the size-limited history of hashX, caching results in
        history_cache.'''
        cache = self.history_cache
        if hashX in cache:
            return cache[hashX]
        # History DoS limit.  Each element of history is about 99
        # bytes when encoded as JSON.  This limits resource usage
        # on bloated history requests, and uses a smaller divisor
        # so large requests are logged before refusing them.
        max_entries = self.env.max_send // 97
        cache[hashX] = await self.db.limited_history(hashX, limit=max_entries)
        return cache[hashX]

    async def _notify_sessions(self, height, touched):
        '''Notify sessions about height changes and touched hashXs.

        If height differs from notified_height, the header subscription
        results are refreshed and touched hashXs are evicted from
        history_cache.  Then session.notify(touched, height_changed) is
        spawned on every session.
        '''
        height_changed = height != self.notified_height
        if height_changed:
            await self._refresh_hsub_results(height)
            # Invalidate our history cache for touched hashXs
            hc = self.history_cache
            for hashX in set(hc).intersection(touched):
                del hc[hashX]

        # Iterate a snapshot (as serve() does): sessions may be added or
        # removed while we await spawn(), which would otherwise raise
        # "Set changed size during iteration".
        for session in list(self.sessions):
            await session.spawn(session.notify, touched, height_changed)

    def add_session(self, session):
        '''Register session, wake _manage_servers, and return the
        SessionGroup the session belongs to.

        Groups are numbered by 900-second intervals between the manager's
        start_time and the session's start_time; a new SessionGroup is
        created when that number changes.
        '''
        self.sessions.add(session)
        self.session_event.set()
        gid = int(session.start_time - self.start_time) // 900
        if self.cur_group.gid == gid:
            return self.cur_group
        self.cur_group = SessionGroup(gid)
        return self.cur_group

    def remove_session(self, session):
        '''Remove a session from our sessions list if there.

        Uses set.discard, so removing a session that is not registered
        (e.g. already removed) is a no-op instead of raising KeyError.
        session_event is set either way so _manage_servers re-checks the
        session count.
        '''
        self.sessions.discard(session)
        self.session_event.set()

    def new_subscription(self):
        '''Reserve room for one more subscription.

        subs_room caches the remaining server-wide room; only when it is
        used up is it recomputed as max_subs minus _sub_count().  Raises
        RPCError(BAD_REQUEST) if there is still no room.
        '''
        if self.subs_room <= 0:
            self.subs_room = self.max_subs - self._sub_count()
        if self.subs_room <= 0:
            raise RPCError(BAD_REQUEST, f'server subscription limit '
                           f'{self.max_subs:,d} reached')
        self.subs_room -= 1