Example #1
    async def _main(self, loop):
        '''Run the server application:

        - record start time
        - install SIGINT and SIGTERM handlers to trigger shutdown_event
        - set loop's exception handler to suppress unwanted messages
        - run the event loop until serve() completes
        '''
        def on_signal(signame):
            shutdown_event.set()
            self.logger.warning(f'received {signame} signal, '
                                f'initiating shutdown')

        self.start_time = time.time()
        for signame in ('SIGINT', 'SIGTERM'):
            loop.add_signal_handler(getattr(signal, signame),
                                    partial(on_signal, signame))
        loop.set_exception_handler(self.on_exception)

        shutdown_event = asyncio.Event()
        try:
            async with TaskGroup() as group:
                server_task = await group.spawn(self.serve(shutdown_event))
                # Wait for shutdown, log on receipt of the event
                await shutdown_event.wait()
                self.logger.info('shutting down')
                server_task.cancel()
        finally:
            await loop.shutdown_asyncgens()

        # Prevent some silly logs
        await asyncio.sleep(0.001)
        # Finally, work around an apparent asyncio bug that causes log
        # spew on shutdown for partially opened SSL sockets
        try:
            del asyncio.sslproto._SSLProtocolTransport.__del__
        except Exception:
            pass

        self.logger.info('shutdown complete')
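
The shape of this method — install signal handlers that set an event, spawn the server into a group, cancel it once the event fires — is the core pattern. A minimal, self-contained sketch of the same idea, assuming aiorpcx's TaskGroup (serve() and the other names here are illustrative stand-ins, not the method above):

import asyncio
import signal
from aiorpcx import TaskGroup

async def serve():
    await asyncio.sleep(3600)           # stand-in for real server work

async def main():
    shutdown_event = asyncio.Event()
    loop = asyncio.get_running_loop()
    # SIGINT/SIGTERM only set the event; the coroutine below reacts to it
    for signame in ('SIGINT', 'SIGTERM'):
        loop.add_signal_handler(getattr(signal, signame), shutdown_event.set)
    async with TaskGroup() as group:
        server_task = await group.spawn(serve())
        await shutdown_event.wait()     # block until a signal arrives
        server_task.cancel()            # group exit awaits the cancelled task

asyncio.run(main())
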
Example #2
 async def _request_fee_estimates(self, interface):
     session = interface.session
     from .simple_config import FEE_ETA_TARGETS
     self.config.requested_fee_estimates()
     async with TaskGroup() as group:
         histogram_task = await group.spawn(session.send_request('mempool.get_fee_histogram'))
         fee_tasks = []
         for i in FEE_ETA_TARGETS:
             fee_tasks.append((i, await group.spawn(session.send_request('blockchain.estimatefee', [i]))))
     self.config.mempool_fees = histogram = histogram_task.result()
     if histogram == []:
          # An empty histogram is unusable; substitute dummy data.
         self.config.mempool_fees = histogram = [[11, 10000000], [11, 10000000], [11, 10000000], [11, 10000000], [11, 10000000],
                                                 [11, 10000000], [11, 10000000], [11, 10000000], [11, 10000000], [11, 10000000]]
     self.print_error('fee_histogram', histogram)
     self.notify('fee_histogram')
     for i, task in fee_tasks:
         fee = int(task.result() * COIN)
         self.print_error("fee_estimates[%d]" % i, fee)
         if fee < 0: continue
         self.config.update_fee_estimates(i, fee)
     self.notify('fee')
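
Note that the task results are read only after the async with block exits: at that point the group has joined every spawned task, so .result() is guaranteed to be set (though it re-raises if the task failed). A condensed sketch of this fan-out-then-collect shape, assuming aiorpcx's TaskGroup; fetch() is a hypothetical stand-in for session.send_request:

from aiorpcx import TaskGroup

async def fetch(i):
    return i * 2                # stand-in for a network request

async def fan_out(targets):
    tasks = []
    async with TaskGroup() as group:
        for i in targets:
            tasks.append((i, await group.spawn(fetch(i))))
    # The group has joined all tasks here, so result() cannot raise
    # InvalidStateError; it may still re-raise the task's exception.
    return {i: task.result() for i, task in tasks}
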
Example #3
    async def fetch_and_process_blocks(self, caught_up_event):
        '''Fetch, process and index blocks from the daemon.

        Sets caught_up_event when first caught up.  Flushes to disk
        and shuts down cleanly if cancelled.

        This matters mainly during initial sync: if ElectrumX is asked
        to shut down when a large number of blocks have been processed
        but not yet written to disk, it should write them to disk
        before exiting, as otherwise a significant amount of work
        would be lost.
        '''
        self._caught_up_event = caught_up_event
        await self._first_open_dbs()
        try:
            async with TaskGroup() as group:
                await group.spawn(self.prefetcher.main_loop(self.height))
                await group.spawn(self._process_prefetched_blocks())
        finally:
            # Shut down block processing
            self.logger.info('flushing to DB for a clean shutdown...')
            await self.flush(True)
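
The try/finally is what makes the shutdown clean: cancelling this coroutine cancels both spawned tasks via the group, and the flush in the finally clause still runs. A minimal sketch of the same shape, assuming aiorpcx's TaskGroup; the worker and flush names are illustrative:

from aiorpcx import TaskGroup

async def fetch_blocks(): ...
async def process_blocks(): ...

async def run_until_cancelled(flush):
    try:
        async with TaskGroup() as group:
            await group.spawn(fetch_blocks())
            await group.spawn(process_blocks())
    finally:
        # runs on normal exit, on error, and on CancelledError alike
        await flush()
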
Example #4
    async def _request_missing_txs(self,
                                   hist,
                                   *,
                                   allow_server_not_finding_tx=False):
        # "hist" is a list of [tx_hash, tx_height] lists
        transaction_hashes = []
        for tx_hash, tx_height in hist:
            if tx_hash in self.requested_tx:
                continue
            if self.wallet.db.get_transaction(tx_hash):
                continue
            transaction_hashes.append(tx_hash)
            self.requested_tx[tx_hash] = tx_height

        if not transaction_hashes: return
        async with TaskGroup() as group:
            for tx_hash in transaction_hashes:
                await group.spawn(
                    self._get_transaction(
                        tx_hash,
                        allow_server_not_finding_tx=allow_server_not_finding_tx
                    ))
Example #5
 async def serve(self, notifications, event):
     '''Start the RPC server if enabled.  When the event is triggered,
     start TCP and SSL servers.'''
     try:
         if self.env.rpc_port is not None:
             await self._start_server('RPC', self.env.cs_host(for_rpc=True),
                                      self.env.rpc_port)
         await event.wait()
         self.logger.info(f'max session count: {self.env.max_sessions:,d}')
         self.logger.info(f'session timeout: '
                          f'{self.env.session_timeout:,d} seconds')
         self.logger.info('session bandwidth limit {:,d} bytes'
                          .format(self.env.bandwidth_limit))
         self.logger.info('max response size {:,d} bytes'
                          .format(self.env.max_send))
         self.logger.info('max subscriptions across all sessions: {:,d}'
                          .format(self.max_subs))
         self.logger.info('max subscriptions per session: {:,d}'
                          .format(self.env.max_session_subs))
         if self.env.drop_client is not None:
             self.logger.info('drop clients matching: {}'
                              .format(self.env.drop_client.pattern))
         # Start notifications; initialize hsub_results
         await notifications.start(self.db.db_height, self._notify_sessions)
         await self._start_external_servers()
         # Peer discovery should start after the external servers
         # because we connect to ourselves
         async with TaskGroup() as group:
             await group.spawn(self.peer_mgr.discover_peers())
             await group.spawn(self._clear_stale_sessions())
             await group.spawn(self._log_sessions())
             await group.spawn(self._manage_servers())
     finally:
         # Close servers then sessions
         await self._close_servers(list(self.servers.keys()))
         for session in list(self.sessions):
             await session.spawn(session.close(force_after=1))
         for session in list(self.sessions):
             await session.closed_event.wait()
Example #6
 async def _request_fee_estimates(self, interface):
     session = interface.session
     from .simple_config import FEE_ETA_TARGETS
     self.config.requested_fee_estimates()
     async with TaskGroup() as group:
         histogram_task = await group.spawn(
             session.send_request('mempool.get_fee_histogram'))
         fee_tasks = []
         for i in FEE_ETA_TARGETS:
             fee_tasks.append((i, await group.spawn(
                 session.send_request('blockchain.estimatefee', [i]))))
     self.config.mempool_fees = histogram = histogram_task.result()
     self.print_error(f'fee_histogram {histogram}')
     self.notify('fee_histogram')
     fee_estimates_eta = {}
     for nblock_target, task in fee_tasks:
         fee = int(task.result() * COIN)
         fee_estimates_eta[nblock_target] = fee
         if fee < 0: continue
         self.config.update_fee_estimates(nblock_target, fee)
     self.print_error(f'fee_estimates {fee_estimates_eta}')
     self.notify('fee')
Example #7
 async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence):
     num_connecting = len(self.connecting)
     for server in servers:
         self._start_interface(server)
     # sleep a bit
     for _ in range(10):
         if len(self.connecting) < num_connecting:
             break
         await asyncio.sleep(1)
     responses = dict()
     async def get_response(iface: Interface):
         try:
             res = await iface.session.send_request(method, params, timeout=10)
         except Exception as e:
             res = e
         responses[iface.server] = res
     async with TaskGroup() as group:
         for server in servers:
             interface = self.interfaces.get(server)
             if interface:
                 await group.spawn(get_response(interface))
     return responses
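
Because each request is wrapped in a coroutine that catches its own exception and records it as the response value, one failing server cannot cancel the whole group; the group only ever sees tasks that finish cleanly. A condensed sketch of that shape, assuming aiorpcx's TaskGroup; query() is a hypothetical stand-in for send_request:

from aiorpcx import TaskGroup

async def query(server):
    ...                                 # stand-in for the real request

async def send_to_all(servers):
    responses = {}
    async def get_response(server):
        try:
            responses[server] = await query(server)
        except Exception as e:
            responses[server] = e       # record the failure instead of raising
    async with TaskGroup() as group:
        for server in servers:
            await group.spawn(get_response(server))
    return responses
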
Example #8
    def __init__(self, env, db):
        self.logger = class_logger(__name__, self.__class__.__name__)
        # Initialise the Peer class
        Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
        self.env = env
        self.db = db

        # Our clearnet and Tor Peers, if any
        sclass = env.coin.SESSIONCLS
        self.myselves = [
            Peer(ident.host, sclass.server_features(env), 'env')
            for ident in env.identities
        ]
        self.server_version_args = sclass.server_version_args()
        # Peers have one entry per hostname.  Once connected, the
        # ip_addr property is either None, an onion peer, or the
        # IP address that was connected to.  Adding a peer will evict
        # any other peers with the same host name or IP address.
        self.peers = set()
        self.permit_onion_peer_time = time.time()
        self.proxy = None
        self.group = TaskGroup()
Example #9
    async def serve(self, shutdown_event):
        '''Start the RPC server and wait for the mempool to synchronize.  Then
        start serving external clients.
        '''
        reqd_version = (0, 5, 9)
        if aiorpcx_version != reqd_version:
            raise RuntimeError('ElectrumX requires aiorpcX version '
                               f'{version_string(reqd_version)}')

        env = self.env
        min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
        self.logger.info(f'software version: {electrumx.version}')
        self.logger.info(f'aiorpcX version: {version_string(aiorpcx_version)}')
        self.logger.info(f'supported protocol versions: {min_str}-{max_str}')
        self.logger.info(f'event loop policy: {env.loop_policy}')
        self.logger.info(f'reorg limit is {env.reorg_limit:,d} blocks')

        notifications = Notifications()
        daemon = env.coin.DAEMON(env)
        BlockProcessor = env.coin.BLOCK_PROCESSOR
        MemPool = env.coin.MEM_POOL
        bp = BlockProcessor(env, daemon, notifications)
        mempool = MemPool(env.coin, daemon, notifications, bp.lookup_utxos)
        chain_state = ChainState(env, daemon, bp, notifications)
        session_mgr = SessionManager(env, chain_state, mempool, notifications,
                                     shutdown_event)

        caught_up_event = Event()
        serve_externally_event = Event()
        synchronized_event = Event()

        async with TaskGroup() as group:
            await group.spawn(session_mgr.serve(serve_externally_event))
            await group.spawn(bp.fetch_and_process_blocks(caught_up_event))
            await caught_up_event.wait()
            await group.spawn(mempool.keep_synchronized(synchronized_event))
            await synchronized_event.wait()
            serve_externally_event.set()
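
Awaiting events between spawns, while the group is still open, is what sequences the startup: block processing must catch up before the mempool synchronizes, and the mempool must synchronize before external serving begins. A minimal sketch of the staging, assuming aiorpcx's TaskGroup and asyncio-style events; the stage coroutines are illustrative:

import asyncio
from aiorpcx import TaskGroup

async def stage_one(done):
    done.set()                  # stand-in: set once caught up

async def stage_two(done):
    done.set()                  # stand-in: set once synchronized

async def start():
    caught_up = asyncio.Event()
    synchronized = asyncio.Event()
    async with TaskGroup() as group:
        await group.spawn(stage_one(caught_up))
        await caught_up.wait()                  # gate the next stage
        await group.spawn(stage_two(synchronized))
        await synchronized.wait()
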
Example #10
    async def _verify_peer(self, session, peer):
        if not peer.is_tor:
            address = session.peer_address()
            if address:
                peer.ip_addr = address[0]

        # server.version goes first
        message = 'server.version'
        result = await session.send_request(message, self.server_version_args)
        assert_good(message, result, list)

        # Protocol version 1.1 returns a pair with the version first
        if len(result) != 2 or not all(isinstance(x, str) for x in result):
            raise BadPeerError(f'bad server.version result: {result}')
        server_version, protocol_version = result
        peer.server_version = server_version
        peer.features['server_version'] = server_version
        ptuple = protocol_tuple(protocol_version)

        async with TaskGroup() as g:
            await g.spawn(self._send_headers_subscribe(session, peer, ptuple))
            await g.spawn(self._send_server_features(session, peer))
            await g.spawn(self._send_peers_subscribe(session, peer))
Example #11
    async def _request_missing_txs(self, hist, *, allow_server_not_finding_tx=False):
        # "hist" is a list of [tx_hash, tx_height, tx_type: Optional[str]] lists
        transaction_hashes_and_types = []
        for item in hist:
            tx_hash = item[0]
            tx_height = item[1]
            # keep backward compatibility when fetching history from a db
            # where the transaction type does not exist
            try:
                tx_type = TxType.from_str(item[2])
            except IndexError:
                tx_type = TxType.NONVAULT

            if tx_hash in self.requested_tx:
                continue
            if self.wallet.db.get_transaction(tx_hash):
                continue
            transaction_hashes_and_types.append((tx_hash, tx_type))
            self.requested_tx[tx_hash] = tx_height

        if not transaction_hashes_and_types: return
        async with TaskGroup() as group:
            for tx_hash, tx_type in transaction_hashes_and_types:
                await group.spawn(self._get_transaction(tx_hash, tx_type=tx_type, allow_server_not_finding_tx=allow_server_not_finding_tx))
Example #12
    async def _verify_peer(self, session, peer):
        if self._is_blacklisted(peer.host):
            raise BadPeerError('blacklisted')

        if not peer.is_tor:
            address = session.peer_address()
            if address:
                peer.ip_addr = address[0]

        # server.version goes first
        message = 'server.version'
        result = await session.send_request(message, self.server_version_args)
        assert_good(message, result, list)

        # Protocol version 1.1 returns a pair with the version first
        if len(result) != 2 or not all(isinstance(x, str) for x in result):
            raise BadPeerError(f'bad server.version result: {result}')
        server_version, protocol_version = result
        peer.server_version = server_version
        peer.features['server_version'] = server_version
        ptuple = protocol_tuple(protocol_version)

        async with TaskGroup() as g:
            await g.spawn(self._send_headers_subscribe(session, peer, ptuple))
            await g.spawn(self._send_server_features(session, peer))
            peers_task = await g.spawn(
                self._send_peers_subscribe(session, peer))

        # Process reported peers if remote peer is good
        peers = peers_task.result()
        await self._note_peers(peers)

        features = self._features_to_register(peer, peers)
        if features:
            self.logger.info(f'registering ourself with {peer}')
            # We only care to wait for the response
            await session.send_request('server.add_peer', [features])
Example #13
 def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
     self.name = name
     Logger.__init__(self)
     NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
     self.node_keypair = local_keypair
     self.network = MockNetwork(tx_queue)
     self.taskgroup = TaskGroup()
     self.lnwatcher = None
     self.listen_server = None
     self._channels = {chan.channel_id: chan for chan in chans}
     self.payments = {}
     self.logs = defaultdict(list)
     self.wallet = MockWallet()
     self.features = LnFeatures(0)
     self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
     self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
     self.features |= LnFeatures.VAR_ONION_OPT
     self.features |= LnFeatures.PAYMENT_SECRET_OPT
     self.features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT
     self.pending_payments = defaultdict(asyncio.Future)
     for chan in chans:
         chan.lnworker = self
     self._peers = {}  # bytes -> Peer
     # used in tests
     self.enable_htlc_settle = asyncio.Event()
     self.enable_htlc_settle.set()
     self.enable_htlc_forwarding = asyncio.Event()
     self.enable_htlc_forwarding.set()
     self.received_mpp_htlcs = dict()
     self.sent_htlcs = defaultdict(asyncio.Queue)
     self.sent_htlcs_routes = dict()
     self.sent_buckets = defaultdict(set)
     self.trampoline_forwarding_failures = {}
     self.inflight_payments = set()
     self.preimages = {}
     self.stopping_soon = False
Example #14
    async def fetch_and_process_blocks(self, caught_up_event):
        '''Fetch, process and index blocks from the daemon.

        Sets caught_up_event when first caught up.  Flushes to disk
        and shuts down cleanly if cancelled.

        This matters mainly during initial sync: if ElectrumX is asked
        to shut down when a large number of blocks have been processed
        but not yet written to disk, it should write them to disk
        before exiting, as otherwise a significant amount of work
        would be lost.
        '''
        self._caught_up_event = caught_up_event
        await self._first_open_dbs()
        try:
            async with TaskGroup() as group:
                await group.spawn(self.prefetcher.main_loop(self.height))
                await group.spawn(self._process_blocks())
        # Don't flush for arbitrary exceptions as they might be a cause or consequence of
        # corrupted data
        except CancelledError:
            self.logger.info('flushing to DB for a clean shutdown...')
            await self.run_with_lock(self.flush(True))
            self.logger.info('flushed cleanly')
Example #15
async def test_unordered_UTXOs():
    api = API()
    api.initialize()
    mempool = MemPool(coin, api)
    event = Event()
    async with TaskGroup() as group:
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        await group.cancel_remaining()

    # Check the default dict is handled properly
    prior_len = len(mempool.hashXs)
    assert await mempool.unordered_UTXOs(os.urandom(HASHX_LEN)) == []
    assert prior_len == len(mempool.hashXs)

    # Test all hashXs
    utxos = api.UTXOs()
    for hashX in api.hashXs:
        mempool_result = await mempool.unordered_UTXOs(hashX)
        our_result = utxos.get(hashX, [])
        assert set(our_result) == {
            dataclasses.astuple(mr)
            for mr in mempool_result
        }
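
This test uses the two-argument form group.spawn(mempool.keep_synchronized, event); together with the coroutine-object form used elsewhere in these examples, that suggests spawn accepts either a coroutine or a function plus its arguments. It then uses cancel_remaining() to tear the background task down once the test has what it needs. A minimal sketch of that test shape, assuming aiorpcx's TaskGroup; background() is illustrative:

import asyncio
from aiorpcx import TaskGroup

async def background(ready):
    ready.set()
    await asyncio.sleep(3600)                   # keeps running until cancelled

async def test_shape():
    ready = asyncio.Event()
    async with TaskGroup() as group:
        await group.spawn(background, ready)    # function-plus-args form
        await ready.wait()
        await group.cancel_remaining()          # stop the background task
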
Example #16
 async def keep_synchronized(self, synchronized_event):
     '''Keep the mempool synchronized with the daemon.'''
     async with TaskGroup(wait=any) as group:
         await group.spawn(self._refresh_hashes(synchronized_event))
         await group.spawn(self._refresh_histogram(synchronized_event))
         await group.spawn(self._logging(synchronized_event))
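
TaskGroup(wait=any) changes the join semantics: the group finishes as soon as any member task completes, cancelling the others on exit, so these three loops live and die together. A minimal sketch of the behaviour, assuming aiorpcx's semantics for wait=any:

import asyncio
from aiorpcx import TaskGroup

async def quick():
    await asyncio.sleep(0.1)

async def forever():
    await asyncio.sleep(3600)

async def main():
    async with TaskGroup(wait=any) as group:
        await group.spawn(quick())
        await group.spawn(forever())
    # exits after ~0.1s: quick() finished, forever() was cancelled

asyncio.run(main())
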
Example #17
class Daemon(Logger):

    network: Optional[Network]

    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        Logger.__init__(self)
        self.running = False
        self.running_lock = threading.Lock()
        self.config = config
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        self.asyncio_loop = asyncio.get_event_loop()
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        daemon_jobs = []
        # Setup commands server
        self.commands_server = None
        if listen_jsonrpc:
            self.commands_server = CommandsServer(self, fd)
            daemon_jobs.append(self.commands_server.run())
        # pay server
        self.pay_server = None
        payserver_address = self.config.get_netaddress('payserver_address')
        if not config.get('offline') and payserver_address:
            self.pay_server = PayServer(self, payserver_address)
            daemon_jobs.append(self.pay_server.run())
        # server-side watchtower
        self.watchtower = None
        watchtower_address = self.config.get_netaddress('watchtower_address')
        if not config.get('offline') and watchtower_address:
            self.watchtower = WatchTowerServer(self.network,
                                               watchtower_address)
            daemon_jobs.append(self.watchtower.run)
        if self.network:
            self.network.start(jobs=[self.fx.run])

        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs),
                                         self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        if jobs is None:
            jobs = []
        self.logger.info("starting taskgroup.")
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                # run forever (until cancelled)
                await group.spawn(asyncio.Event().wait)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.exception("taskgroup died.")
        finally:
            self.logger.info("taskgroup stopped.")

    def load_wallet(self,
                    path,
                    password,
                    *,
                    manual_upgrades=True) -> Optional[Abstract_Wallet]:
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if not wallet:
            return False
        wallet.stop()
        return True

    def run_daemon(self):
        self.running = True
        try:
            while self.is_running():
                time.sleep(0.1)
        except KeyboardInterrupt:
            self.running = False
        self.on_stop()

    def is_running(self):
        with self.running_lock:
            return self.running and not self.taskgroup.closed()

    def stop(self):
        with self.running_lock:
            self.running = False

    def on_stop(self):
        if self.gui_object:
            self.gui_object.stop()
        # stop network/wallets
        for k, wallet in self._wallets.items():
            wallet.stop()
        if self.network:
            self.logger.info("shutting down network")
            self.network.stop()
        self.logger.info("stopping taskgroup")
        fut = asyncio.run_coroutine_threadsafe(
            self.taskgroup.cancel_remaining(), self.asyncio_loop)
        try:
            fut.result(timeout=2)
        except (concurrent.futures.TimeoutError,
                concurrent.futures.CancelledError, asyncio.CancelledError):
            pass
        self.logger.info("removing lockfile")
        remove_lockfile(get_lockfile(self.config))
        self.logger.info("stopped")

    def run_gui(self, config, plugins):
        threading.current_thread().setName('GUI')
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrum.gui.' + gui_name, fromlist=['electrum'])
            self.gui_object = gui.ElectrumGui(config, self, plugins)
            self.gui_object.main()
        except BaseException as e:
            self.logger.error(
                f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            self.on_stop()
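
_run keeps the group open indefinitely by spawning asyncio.Event().wait, an await that never completes on its own, so the group only exits when on_stop() calls cancel_remaining() from the calling thread via run_coroutine_threadsafe. A minimal sketch of that lifecycle, assuming aiorpcx's TaskGroup; names are illustrative:

import asyncio
from aiorpcx import TaskGroup

taskgroup = TaskGroup()

async def run(jobs):
    async with taskgroup as group:
        for job in jobs:
            await group.spawn(job)
        await group.spawn(asyncio.Event().wait())   # never finishes on its own

def stop(loop):
    # safe to call from another thread; unblocks run() by cancelling the group
    fut = asyncio.run_coroutine_threadsafe(taskgroup.cancel_remaining(), loop)
    fut.result(timeout=2)
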
Example #18
class Daemon(Logger):
    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        Logger.__init__(self)
        self.auth_lock = asyncio.Lock()
        self.running = False
        self.running_lock = threading.Lock()
        self.config = config
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        self.asyncio_loop = asyncio.get_event_loop()
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        daemon_jobs = []
        # Setup JSONRPC server
        if listen_jsonrpc:
            daemon_jobs.append(self.start_jsonrpc(config, fd))
        # request server
        self.pay_server = None
        if not config.get('offline') and self.config.get('run_payserver'):
            self.pay_server = PayServer(self)
            daemon_jobs.append(self.pay_server.run())
        # server-side watchtower
        self.watchtower = None
        if not config.get('offline') and self.config.get('run_watchtower'):
            self.watchtower = WatchTowerServer(self.network)
            daemon_jobs.append(self.watchtower.run)
        if self.network:
            self.network.start(jobs=[self.fx.run])

        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs),
                                         self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        if jobs is None:
            jobs = []
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                # run forever (until cancelled)
                await group.spawn(asyncio.Event().wait)
        except BaseException as e:
            self.logger.exception('daemon.taskgroup died.')
        finally:
            self.logger.info("stopping daemon.taskgroup")

    async def authenticate(self, headers):
        if self.rpc_password == '':
            # RPC authentication is disabled
            return
        auth_string = headers.get('Authorization', None)
        if auth_string is None:
            raise AuthenticationInvalidOrMissing('CredentialsMissing')
        basic, _, encoded = auth_string.partition(' ')
        if basic != 'Basic':
            raise AuthenticationInvalidOrMissing('UnsupportedType')
        encoded = to_bytes(encoded, 'utf8')
        credentials = to_string(b64decode(encoded), 'utf8')
        username, _, password = credentials.partition(':')
        if not (constant_time_compare(username, self.rpc_user)
                and constant_time_compare(password, self.rpc_password)):
            await asyncio.sleep(0.050)
            raise AuthenticationCredentialsInvalid('Invalid Credentials')

    async def handle(self, request):
        async with self.auth_lock:
            try:
                await self.authenticate(request.headers)
            except AuthenticationInvalidOrMissing:
                return web.Response(
                    headers={"WWW-Authenticate": "Basic realm=Electrum"},
                    text='Unauthorized',
                    status=401)
            except AuthenticationCredentialsInvalid:
                return web.Response(text='Forbidden', status=403)
        request = await request.text()
        response = await jsonrpcserver.async_dispatch(request,
                                                      methods=self.methods)
        if isinstance(response, jsonrpcserver.response.ExceptionResponse):
            self.logger.error(f"error handling request: {request}",
                              exc_info=response.exc)
            # this exposes the error message to the client
            response.message = str(response.exc)
        if response.wanted:
            return web.json_response(response.deserialized(),
                                     status=response.http_status)
        else:
            return web.Response()

    async def start_jsonrpc(self, config: SimpleConfig, fd):
        self.app = web.Application()
        self.app.router.add_post("/", self.handle)
        self.rpc_user, self.rpc_password = get_rpc_credentials(config)
        self.methods = jsonrpcserver.methods.Methods()
        self.methods.add(self.ping)
        self.methods.add(self.gui)
        self.cmd_runner = Commands(config=self.config,
                                   network=self.network,
                                   daemon=self)
        for cmdname in known_commands:
            self.methods.add(getattr(self.cmd_runner, cmdname))
        self.methods.add(self.run_cmdline)
        self.host = config.get('rpchost', '127.0.0.1')
        self.port = config.get('rpcport', 0)
        self.runner = web.AppRunner(self.app)
        await self.runner.setup()
        site = web.TCPSite(self.runner, self.host, self.port)
        await site.start()
        socket = site._server.sockets[0]
        os.write(fd, bytes(repr((socket.getsockname(), time.time())), 'utf8'))
        os.close(fd)

    async def ping(self):
        return True

    async def gui(self, config_options):
        if self.gui_object:
            if hasattr(self.gui_object, 'new_window'):
                path = self.config.get_wallet_path(use_gui_last_wallet=True)
                self.gui_object.new_window(path, config_options.get('url'))
                response = "ok"
            else:
                response = "error: current GUI does not support multiple windows"
        else:
            response = "Error: Electrum is running in daemon mode. Please stop the daemon first."
        return response

    def load_wallet(self,
                    path,
                    password,
                    *,
                    manual_upgrades=True) -> Optional[Abstract_Wallet]:
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        self.wallet = wallet
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Abstract_Wallet:
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if not wallet:
            return False
        wallet.stop_threads()
        return True

    async def run_cmdline(self, config_options):
        cmdname = config_options['cmd']
        cmd = known_commands[cmdname]
        # arguments passed to function
        args = [config_options.get(x) for x in cmd.params]
        # decode json arguments
        args = [json_decode(i) for i in args]
        # options
        kwargs = {}
        for x in cmd.options:
            kwargs[x] = config_options.get(x)
        if cmd.requires_wallet:
            kwargs['wallet_path'] = config_options.get('wallet_path')
        func = getattr(self.cmd_runner, cmd.name)
        # fixme: not sure how to retrieve message in jsonrpcclient
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            result = {'error': str(e)}
        return result

    def run_daemon(self):
        self.running = True
        try:
            while self.is_running():
                time.sleep(0.1)
        except KeyboardInterrupt:
            self.running = False
        self.on_stop()

    def is_running(self):
        with self.running_lock:
            return self.running and not self.taskgroup.closed()

    def stop(self):
        with self.running_lock:
            self.running = False

    def on_stop(self):
        if self.gui_object:
            self.gui_object.stop()
        # stop network/wallets
        for k, wallet in self._wallets.items():
            wallet.stop_threads()
        if self.network:
            self.logger.info("shutting down network")
            self.network.stop()
        self.logger.info("stopping taskgroup")
        fut = asyncio.run_coroutine_threadsafe(
            self.taskgroup.cancel_remaining(), self.asyncio_loop)
        try:
            fut.result(timeout=2)
        except (concurrent.futures.TimeoutError,
                concurrent.futures.CancelledError, asyncio.CancelledError):
            # fut is a concurrent.futures.Future, so it raises the
            # concurrent.futures flavours of these exceptions
            pass
        self.logger.info("removing lockfile")
        remove_lockfile(get_lockfile(self.config))
        self.logger.info("stopped")

    def run_gui(self, config, plugins):
        threading.current_thread().setName('GUI')
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrum.gui.' + gui_name, fromlist=['electrum'])
            self.gui_object = gui.ElectrumGui(config, self, plugins)
            self.gui_object.main()
        except BaseException as e:
            self.logger.error(
                f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            self.on_stop()
Example #19
    async def _process_mempool(self, all_hashes, touched, mempool_height):
        # Re-sync with the new set of hashes
        txs = self.txs
        hashXs = self.hashXs

        tx_to_create = self.tx_to_asset_create
        tx_to_reissue = self.tx_to_asset_reissue
        creates = self.asset_creates
        reissues = self.asset_reissues

        if mempool_height != self.api.db_height():
            raise DBSyncError

        # First handle txs that have disappeared
        for tx_hash in set(txs).difference(all_hashes):
            tx = txs.pop(tx_hash)

            reissued_asset = tx_to_reissue.pop(tx_hash, None)
            if reissued_asset:
                del reissues[reissued_asset]

            created_asset = tx_to_create.pop(tx_hash, None)
            if created_asset:
                del creates[created_asset]

            tx_hashXs = set(hashX for hashX, value, _, _ in tx.in_pairs)
            tx_hashXs.update(hashX for hashX, value, _, _ in tx.out_pairs)
            for hashX in tx_hashXs:
                hashXs[hashX].remove(tx_hash)
                if not hashXs[hashX]:
                    del hashXs[hashX]
            touched.update(tx_hashXs)

        # Process new transactions
        new_hashes = list(all_hashes.difference(txs))
        if new_hashes:
            group = TaskGroup()
            for hashes in chunks(new_hashes, 200):
                coro = self._fetch_and_accept(hashes, all_hashes, touched)
                await group.spawn(coro)

            tx_map = {}
            utxo_map = {}
            async for task in group:
                (deferred,
                 unspent), internal_creates, internal_reissues = task.result()

                # Store asset changes
                for asset, stats in internal_creates.items():
                    tx_to_create[hex_str_to_hash(
                        stats['source']['tx_hash'])] = asset
                    creates[asset] = stats

                for asset, stats in internal_reissues.items():
                    tx_to_reissue[hex_str_to_hash(
                        stats['source']['tx_hash'])] = asset
                    reissues[asset] = stats

                tx_map.update(deferred)
                utxo_map.update(unspent)

            prior_count = 0
            # FIXME: this is not particularly efficient
            while tx_map and len(tx_map) != prior_count:
                prior_count = len(tx_map)
                tx_map, utxo_map = self._accept_transactions(
                    tx_map, utxo_map, touched)
            if tx_map:
                self.logger.error(f'{len(tx_map)} txs dropped')

        return touched
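
Here the group is iterated rather than just joined: async for task in group yields each task as it completes, so results are folded in as soon as each 200-hash chunk returns instead of waiting for the slowest chunk. A minimal sketch of that shape, assuming aiorpcx's TaskGroup supports async iteration as used above; fetch_chunk() is illustrative:

from aiorpcx import TaskGroup

async def fetch_chunk(chunk):
    return {h: None for h in chunk}     # stand-in for a daemon fetch

async def gather_chunks(chunks):
    merged = {}
    group = TaskGroup()
    for chunk in chunks:
        await group.spawn(fetch_chunk(chunk))
    async for task in group:            # tasks arrive in completion order
        merged.update(task.result())
    return merged
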
Example #20
 async def cleanup_lnworkers():
     async with TaskGroup() as group:
         for lnworker in self._lnworkers_created:
             await group.spawn(lnworker.stop())
     self._lnworkers_created.clear()
Example #21
    async def _process_mempool(self, all_hashes, touched, mempool_height):
        # Re-sync with the new set of hashes
        txs = self.txs
        hashXs = self.hashXs
        nameXs = self.nameXs

        if mempool_height != self.api.db_height():
            raise DBSyncError

        # First handle txs that have disappeared
        for tx_hash in set(txs).difference(all_hashes):
            tx = txs.pop(tx_hash)

            # If it is a name transaction, remove the index.
            if tx.nameout:
                if nameXs[tx.nameout]:
                    nameXs[tx.nameout].remove(tx_hash)
                if not nameXs[tx.nameout]:
                    del nameXs[tx.nameout]

            # Kevacoin specific
            if tx.keyout:
                if nameXs[tx.keyout]:
                    nameXs[tx.keyout].remove(tx_hash)
                if not nameXs[tx.keyout]:
                    del nameXs[tx.keyout]

            if tx.namekeyout:
                if nameXs[tx.namekeyout]:
                    nameXs[tx.namekeyout].remove(tx_hash)
                if not nameXs[tx.namekeyout]:
                    del nameXs[tx.namekeyout]

            tx_hashXs = set(hashX for hashX, value in tx.in_pairs)
            tx_hashXs.update(hashX for hashX, value in tx.out_pairs)
            for hashX in tx_hashXs:
                hashXs[hashX].remove(tx_hash)
                if not hashXs[hashX]:
                    del hashXs[hashX]

            touched.update(tx_hashXs)

        # Process new transactions
        new_hashes = list(all_hashes.difference(txs))
        if new_hashes:
            group = TaskGroup()
            for hashes in chunks(new_hashes, 200):
                coro = self._fetch_and_accept(hashes, all_hashes, touched)
                await group.spawn(coro)
            if mempool_height != self.api.db_height():
                raise DBSyncError

            tx_map = {}
            utxo_map = {}
            async for task in group:
                deferred, unspent = task.result()
                tx_map.update(deferred)
                utxo_map.update(unspent)

            prior_count = 0
            # FIXME: this is not particularly efficient
            while tx_map and len(tx_map) != prior_count:
                prior_count = len(tx_map)
                tx_map, utxo_map = self._accept_transactions(
                    tx_map, utxo_map, touched)
            if tx_map:
                self.logger.error(f'{len(tx_map)} txs dropped')

        return touched
Example #22
    async def _verify_peer(self, session, peer):
        # store IP address for peer
        if not peer.is_tor:
            address = session.peer_address()
            if address:
                peer.ip_addr = address[0]

        if self._is_blacklisted(peer):
            raise BadPeerError('blacklisted')

        # Bucket good recent peers; forbid many servers from similar IPs
        # FIXME there's a race here, when verifying multiple peers
        #       that belong to the same bucket ~simultaneously
        recent_peers = self._get_recent_good_peers()
        if peer in recent_peers:
            recent_peers.remove(peer)
        onion_peers = []
        buckets = defaultdict(list)
        for other_peer in recent_peers:
            if other_peer.is_tor:
                onion_peers.append(other_peer)
            else:
                buckets[other_peer.bucket_for_internal_purposes()].append(other_peer)
        if peer.is_tor:
            # keep number of onion peers below half of all peers,
            # but up to 100 is OK regardless
            if len(onion_peers) > len(recent_peers) // 2 >= 100:
                raise BadPeerError('too many onion peers already')
        else:
            bucket = peer.bucket_for_internal_purposes()
            if len(buckets[bucket]) > 0:
                raise BadPeerError(f'too many peers already in bucket {bucket}')

        # server.version goes first
        message = 'server.version'
        result = await session.send_request(message, self.server_version_args)
        assert_good(message, result, list)

        # Protocol version 1.1 returns a pair with the version first
        if len(result) != 2 or not all(isinstance(x, str) for x in result):
            raise BadPeerError(f'bad server.version result: {result}')
        server_version, protocol_version = result
        peer.server_version = server_version
        peer.features['server_version'] = server_version
        ptuple = protocol_tuple(protocol_version)

        async with TaskGroup() as g:
            await g.spawn(self._send_headers_subscribe(session, peer, ptuple))
            await g.spawn(self._send_server_features(session, peer))
            peers_task = await g.spawn(
                self._send_peers_subscribe(session, peer))

        # Process reported peers if remote peer is good
        peers = peers_task.result()
        await self._note_peers(peers)

        features = self._features_to_register(peer, peers)
        if features:
            self.logger.info(f'registering ourself with {peer}')
            # We only care to wait for the response
            await session.send_request('server.add_peer', [features])
Example #23
class Daemon(Logger):

    network: Optional[Network]
    gui_object: Optional[Union['gui.qt.ElectrumGui', 'gui.kivy.ElectrumGui']]

    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        Logger.__init__(self)
        self.config = config
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        if 'wallet_path' in config.cmdline_options:
            self.logger.warning("Ignoring parameter 'wallet_path' for daemon. "
                                "Use the load_wallet command instead.")
        self.asyncio_loop = asyncio.get_event_loop()
        asyncio_wait_time = 0
        while not self.asyncio_loop.is_running():
            if asyncio_wait_time > 30:
                raise Exception('event loop not running for 30 seconds')
            time.sleep(0.1)
            asyncio_wait_time += 0.1
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        self.current_wallet_path = None
        daemon_jobs = []
        # Setup commands server
        self.commands_server = None
        if listen_jsonrpc:
            self.commands_server = CommandsServer(self, fd)
            daemon_jobs.append(self.commands_server.run())
        # pay server
        self.pay_server = None
        payserver_address = self.config.get_netaddress('payserver_address')
        if not config.get('offline') and payserver_address:
            self.pay_server = PayServer(self, payserver_address)
            daemon_jobs.append(self.pay_server.run())
        if self.network:
            self.network.start(jobs=[self.fx.run])

        self.stopping_soon = threading.Event()
        self.stopped_event = asyncio.Event()
        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs),
                                         self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        if jobs is None:
            jobs = []
        self.logger.info("starting taskgroup.")
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                # run forever (until cancelled)
                await group.spawn(asyncio.Event().wait)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.exception("taskgroup died.")
        finally:
            self.logger.info("taskgroup stopped.")
            self.stopping_soon.set()

    def load_wallet(self,
                    path,
                    password,
                    *,
                    manual_upgrades=True,
                    set_current=False) -> Optional[Abstract_Wallet]:
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            if set_current:
                self.current_wallet_path = path
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.upgrade_done:
            storage.backup_old_version()
        if getattr(storage, 'backup_message', None):
            log_backup_msg = ' '.join(storage.backup_message.split('\n'))
            self.logger.info(f'{log_backup_msg}')
            if not self.config.get('detach'):
                util.print_stderr(f'{storage.backup_message}\n')
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        if db.check_unfinished_multisig():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        if set_current:
            self.current_wallet_path = path
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if self.current_wallet_path == path:
            self.current_wallet_path = None
        if not wallet:
            return False
        fut = asyncio.run_coroutine_threadsafe(wallet.stop(),
                                               self.asyncio_loop)
        fut.result()
        return True

    def run_daemon(self):
        try:
            self.stopping_soon.wait()
        except KeyboardInterrupt:
            self.stopping_soon.set()
        self.on_stop()

    async def stop(self):
        self.stopping_soon.set()
        await self.stopped_event.wait()

    def on_stop(self):
        try:
            self.logger.info("on_stop() entered. initiating shutdown")
            if self.gui_object:
                self.gui_object.stop()

            async def stop_async():
                self.logger.info("stopping all wallets")
                async with TaskGroup() as group:
                    for k, wallet in self._wallets.items():
                        await group.spawn(wallet.stop())
                self.logger.info("stopping network and taskgroup")
                async with ignore_after(2):
                    async with TaskGroup() as group:
                        if self.network:
                            await group.spawn(
                                self.network.stop(full_shutdown=True))
                        await group.spawn(self.taskgroup.cancel_remaining())

            fut = asyncio.run_coroutine_threadsafe(stop_async(),
                                                   self.asyncio_loop)
            fut.result()
        finally:
            self.logger.info("removing lockfile")
            remove_lockfile(get_lockfile(self.config))
            self.logger.info("stopped")
            self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set)

    def run_gui(self, config, plugins):
        threading.current_thread().setName('GUI')
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrum_zcash.gui.' + gui_name,
                             fromlist=['electrum_zcash'])
            self.gui_object = gui.ElectrumGui(config, self, plugins)
            self.gui_object.main()
        except BaseException as e:
            self.logger.error(
                f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            self.on_stop()
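
stop_async() bounds the network/taskgroup teardown with ignore_after(2): if the inner group has not finished within two seconds, it is cancelled and shutdown proceeds anyway. A minimal sketch of that bounded-cleanup shape, assuming aiorpcx's TaskGroup and ignore_after; slow_cleanup() is illustrative:

import asyncio
from aiorpcx import TaskGroup, ignore_after

async def slow_cleanup():
    await asyncio.sleep(60)     # stand-in for a teardown that may hang

async def shutdown():
    async with ignore_after(2):             # allow at most 2 seconds
        async with TaskGroup() as group:
            await group.spawn(slow_cleanup())
    # control reaches here after ~2s even if cleanup never finished
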
Example #24
class Network(PrintError):
    """The Network class manages a set of connections to remote electrum
    servers, each connected socket is handled by an Interface() object.
    Connections are initiated by a Connection() thread which stops once
    the connection succeeds or fails.

    Our external API:

    - Member functions get_header(), get_interfaces(), get_local_height(),
          get_parameters(), get_server_height(), get_status_value(),
          is_connected(), set_parameters(), stop()
    """
    verbosity_filter = 'n'

    def __init__(self, config=None):
        global INSTANCE
        INSTANCE = self
        if config is None:
            config = {}  # Do not use mutables as default values!
        self.config = SimpleConfig(config) if isinstance(config, dict) else config
        self.num_server = 10 if not self.config.get('oneserver') else 0
        blockchain.blockchains = blockchain.read_blockchains(
            self.config)  # note: needs self.blockchains_lock
        self.print_error("blockchains", list(blockchain.blockchains.keys()))
        self.blockchain_index = config.get('blockchain_index', 0)
        if self.blockchain_index not in blockchain.blockchains.keys():
            self.blockchain_index = 0
        # Server for addresses and transactions
        self.default_server = self.config.get('server', None)
        # Sanitize default server
        if self.default_server:
            try:
                deserialize_server(self.default_server)
            except:
                self.print_error(
                    'Warning: failed to parse server-string; falling back to random.'
                )
                self.default_server = None
        if not self.default_server:
            self.default_server = pick_random_server()

        # locks: if you need to take multiple ones, acquire them in the order they are defined here!
        self.bhi_lock = asyncio.Lock()
        self.interface_lock = threading.RLock()  # <- re-entrant
        self.callback_lock = threading.Lock()
        self.recent_servers_lock = threading.RLock()  # <- re-entrant
        self.blockchains_lock = threading.Lock()

        self.irc_servers = {}  # returned by interface (list from irc)
        # note: needs self.recent_servers_lock
        self.recent_servers = self.read_recent_servers()

        self.banner = ''
        self.donation_address = ''
        self.relay_fee = None
        # callbacks set by the GUI
        self.callbacks = defaultdict(list)  # note: needs self.callback_lock

        dir_path = os.path.join(self.config.path, 'certs')
        util.make_dir(dir_path)

        # retry times
        self.server_retry_time = time.time()
        self.nodes_retry_time = time.time()
        # kick off the network.  interface is the main server we are currently
        # communicating with.  interfaces is the set of servers we are connecting
        # to or have an ongoing connection with
        self.interface = None  # note: needs self.interface_lock
        self.interfaces = {}  # note: needs self.interface_lock
        self.auto_connect = self.config.get('auto_connect', True)
        self.connecting = set()
        self.requested_chunks = set()
        self.socket_queue = queue.Queue()
        self.start_network(
            deserialize_server(self.default_server)[2],
            deserialize_proxy(self.config.get('proxy')))
        self.asyncio_loop = asyncio.get_event_loop()

    @staticmethod
    def get_instance():
        return INSTANCE

    def with_interface_lock(func):
        def func_wrapper(self, *args, **kwargs):
            with self.interface_lock:
                return func(self, *args, **kwargs)

        return func_wrapper

    def with_recent_servers_lock(func):
        def func_wrapper(self, *args, **kwargs):
            with self.recent_servers_lock:
                return func(self, *args, **kwargs)

        return func_wrapper

    def register_callback(self, callback, events):
        with self.callback_lock:
            for event in events:
                self.callbacks[event].append(callback)

    def unregister_callback(self, callback):
        with self.callback_lock:
            for callbacks in self.callbacks.values():
                if callback in callbacks:
                    callbacks.remove(callback)

    def trigger_callback(self, event, *args):
        with self.callback_lock:
            callbacks = self.callbacks[event][:]
        for callback in callbacks:
            if asyncio.iscoroutinefunction(callback):
                # FIXME: if callback throws, we will lose the traceback
                asyncio.run_coroutine_threadsafe(callback(event, *args),
                                                 self.asyncio_loop)
            else:
                callback(event, *args)

    def read_recent_servers(self):
        if not self.config.path:
            return []
        path = os.path.join(self.config.path, "recent_servers")
        try:
            with open(path, "r", encoding='utf-8') as f:
                data = f.read()
                return json.loads(data)
        except Exception:
            return []

    @with_recent_servers_lock
    def save_recent_servers(self):
        if not self.config.path:
            return
        path = os.path.join(self.config.path, "recent_servers")
        s = json.dumps(self.recent_servers, indent=4, sort_keys=True)
        try:
            with open(path, "w", encoding='utf-8') as f:
                f.write(s)
        except Exception:
            pass  # best-effort write; ignore filesystem errors

    @with_interface_lock
    def get_server_height(self):
        return self.interface.tip if self.interface else 0

    def server_is_lagging(self):
        sh = self.get_server_height()
        if not sh:
            self.print_error('no height for main interface')
            return True
        lh = self.get_local_height()
        result = (lh - sh) > 1
        if result:
            self.print_error('%s is lagging (%d vs %d)' %
                             (self.default_server, sh, lh))
        return result

    def set_status(self, status):
        self.connection_status = status
        self.notify('status')

    def is_connected(self):
        return self.interface is not None and self.interface.ready.done()

    def is_connecting(self):
        return self.connection_status == 'connecting'

    async def request_server_info(self, interface):
        await interface.ready
        session = interface.session
        self.banner = await session.send_request('server.banner')
        self.notify('banner')
        self.donation_address = await session.send_request(
            'server.donation_address')
        self.irc_servers = parse_servers(
            await session.send_request('server.peers.subscribe'))
        self.notify('servers')
        await self.request_fee_estimates(interface)
        relayfee = await session.send_request('blockchain.relayfee')
        self.relay_fee = int(relayfee * COIN) if relayfee is not None else None

    async def request_fee_estimates(self, interface):
        session = interface.session
        from .simple_config import FEE_ETA_TARGETS
        self.config.requested_fee_estimates()
        async with TaskGroup() as group:
            histogram_task = await group.spawn(
                session.send_request('mempool.get_fee_histogram'))
            fee_tasks = []
            for i in FEE_ETA_TARGETS:
                fee_tasks.append((i, await group.spawn(
                    session.send_request('blockchain.estimatefee', [i]))))
        self.config.mempool_fees = histogram_task.result()
        self.notify('fee_histogram')
        for i, task in fee_tasks:
            fee = int(task.result() * COIN)
            self.config.update_fee_estimates(i, fee)
            self.print_error("fee_estimates[%d]" % i, fee)
        self.notify('fee')

    def get_status_value(self, key):
        if key == 'status':
            value = self.connection_status
        elif key == 'banner':
            value = self.banner
        elif key == 'fee':
            value = self.config.fee_estimates
        elif key == 'fee_histogram':
            value = self.config.mempool_fees
        elif key == 'updated':
            value = (self.get_local_height(), self.get_server_height())
        elif key == 'servers':
            value = self.get_servers()
        elif key == 'interfaces':
            value = self.get_interfaces()
        else:
            raise Exception(f'unexpected key: {key!r}')
        return value

    def notify(self, key):
        if key in ['status', 'updated']:
            self.trigger_callback(key)
        else:
            self.trigger_callback(key, self.get_status_value(key))

    def get_parameters(self) -> NetworkParameters:
        host, port, protocol = deserialize_server(self.default_server)
        return NetworkParameters(host, port, protocol, self.proxy,
                                 self.auto_connect)

    def get_donation_address(self):
        if self.is_connected():
            return self.donation_address

    @with_interface_lock
    def get_interfaces(self):
        '''The interfaces that are in connected state'''
        return list(self.interfaces.keys())

    @with_recent_servers_lock
    def get_servers(self):
        out = dict(constants.net.DEFAULT_SERVERS)  # copy; don't mutate the shared default dict
        if self.irc_servers:
            out.update(filter_version(self.irc_servers.copy()))
        else:
            for s in self.recent_servers:
                try:
                    host, port, protocol = deserialize_server(s)
                except Exception:
                    continue
                if host not in out:
                    out[host] = {protocol: port}
        if self.config.get('noonion'):
            out = filter_noonion(out)
        return out

    @with_interface_lock
    def start_interface(self, server):
        if server not in self.interfaces and server not in self.connecting:
            if server == self.default_server:
                self.print_error("connecting to %s as new interface" % server)
                self.set_status('connecting')
            self.connecting.add(server)
            self.socket_queue.put(server)

    def start_random_interface(self):
        with self.interface_lock:
            exclude_set = self.disconnected_servers.union(set(self.interfaces))
        server = pick_random_server(self.get_servers(), self.protocol,
                                    exclude_set)
        if server:
            self.start_interface(server)
        return server

    def set_proxy(self, proxy: Optional[dict]):
        self.proxy = proxy
        # Store these somewhere so we can un-monkey-patch
        if not hasattr(socket, "_getaddrinfo"):
            socket._getaddrinfo = socket.getaddrinfo
        if proxy:
            self.print_error('setting proxy', proxy)
            # prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy
            socket.getaddrinfo = lambda *args: [(
                socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))]
        else:
            if sys.platform == 'win32':
                # On Windows, socket.getaddrinfo takes a mutex, and might hold it for up to 10 seconds
                # when dns-resolving. To speed it up drastically, we resolve dns ourselves, outside that lock.
                # see #4421
                socket.getaddrinfo = self._fast_getaddrinfo
            else:
                socket.getaddrinfo = socket._getaddrinfo
        self.trigger_callback('proxy_set', self.proxy)

    @staticmethod
    def _fast_getaddrinfo(host, *args, **kwargs):
        def needs_dns_resolving(host2):
            try:
                ipaddress.ip_address(host2)
                return False  # already valid IP
            except ValueError:
                pass  # not an IP
            if str(host2) in ('localhost', 'localhost.'):
                return False
            return True

        try:
            if needs_dns_resolving(host):
                answers = dns.resolver.query(host)
                addr = str(answers[0])
            else:
                addr = host
        except dns.exception.DNSException:
            # dns failed for some reason, e.g. dns.resolver.NXDOMAIN
            # this is normal. Simply report back failure:
            raise socket.gaierror(11001, 'getaddrinfo failed')
        except BaseException as e:
            # Possibly internal error in dnspython :( see #4483
            # Fall back to original socket.getaddrinfo to resolve dns.
            print_error('dnspython failed to resolve dns with error:', e)
            addr = host
        return socket._getaddrinfo(addr, *args, **kwargs)
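
    # Note: set_proxy and _fast_getaddrinfo monkey-patch the process-wide
    # socket.getaddrinfo.  The original is stashed once as socket._getaddrinfo
    # so it can be restored when the proxy is removed; while a proxy is set,
    # hostnames are passed through unresolved so that DNS happens at the proxy.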

    @with_interface_lock
    def start_network(self, protocol: str, proxy: Optional[dict]):
        assert not self.interface and not self.interfaces
        assert not self.connecting and self.socket_queue.empty()
        self.print_error('starting network')
        self.disconnected_servers = set([])  # note: needs self.interface_lock
        self.protocol = protocol
        self.set_proxy(proxy)
        self.start_interface(self.default_server)

    @with_interface_lock
    def stop_network(self):
        self.print_error("stopping network")
        for interface in list(self.interfaces.values()):
            self.close_interface(interface)
        if self.interface:
            self.close_interface(self.interface)
        assert self.interface is None
        assert not self.interfaces
        self.connecting.clear()
        # Get a new queue - no old pending connections thanks!
        self.socket_queue = queue.Queue()

    def set_parameters(self, net_params: NetworkParameters):
        proxy = net_params.proxy
        proxy_str = serialize_proxy(proxy)
        host, port, protocol = net_params.host, net_params.port, net_params.protocol
        server_str = serialize_server(host, port, protocol)
        # sanitize parameters
        try:
            deserialize_server(serialize_server(host, port, protocol))
            if proxy:
                proxy_modes.index(proxy['mode'])  # raises ValueError on unknown mode
                int(proxy['port'])                # raises ValueError on a bad port
        except Exception:
            return
        self.config.set_key('auto_connect', net_params.auto_connect, False)
        self.config.set_key("proxy", proxy_str, False)
        self.config.set_key("server", server_str, True)
        # abort if changes were not allowed by config
        if self.config.get('server') != server_str or self.config.get(
                'proxy') != proxy_str:
            return
        self.auto_connect = net_params.auto_connect
        if self.proxy != proxy or self.protocol != protocol:
            # Restart the network defaulting to the given server
            with self.interface_lock:
                self.stop_network()
                self.default_server = server_str
                self.start_network(protocol, proxy)
        elif self.default_server != server_str:
            self.switch_to_interface(server_str)
        else:
            self.switch_lagging_interface()
            self.notify('updated')

    def switch_to_random_interface(self):
        '''Switch to a random connected server other than the current one'''
        servers = self.get_interfaces()  # Those in connected state
        if self.default_server in servers:
            servers.remove(self.default_server)
        if servers:
            self.switch_to_interface(random.choice(servers))

    @with_interface_lock
    def switch_lagging_interface(self):
        '''If auto_connect and lagging, switch interface'''
        if self.server_is_lagging() and self.auto_connect:
            # switch to one that has the correct header (not height)
            header = self.blockchain().read_header(self.get_local_height())

            def filt(x):
                a = x[1].tip_header
                b = header
                assert type(a) is type(b)
                return a == b

            filtered = list(
                map(lambda x: x[0], filter(filt, self.interfaces.items())))
            if filtered:
                choice = random.choice(filtered)
                self.switch_to_interface(choice)

    @with_interface_lock
    def switch_to_interface(self, server):
        '''Switch to server as our interface.  If no connection exists nor
        being opened, start a thread to connect.  The actual switch will
        happen on receipt of the connection notification.  Do nothing
        if server already is our interface.'''
        self.default_server = server
        if server not in self.interfaces:
            self.interface = None
            self.start_interface(server)
            return

        i = self.interfaces[server]
        if self.interface != i:
            self.print_error("switching to", server)
            if self.interface is not None:
                # Stop any current interface in order to terminate subscriptions,
                # and to cancel tasks in interface.group.
                # However, for headers sub, give preference to this interface
                # over unknown ones, i.e. start it again right away.
                old_server = self.interface.server
                self.close_interface(self.interface)
                if len(self.interfaces) <= self.num_server:
                    self.start_interface(old_server)

            self.interface = i
            asyncio.get_event_loop().create_task(
                i.group.spawn(self.request_server_info(i)))
            self.trigger_callback('default_server_changed')
            self.set_status('connected')
            self.notify('updated')
            self.notify('interfaces')

    @with_interface_lock
    def close_interface(self, interface):
        if interface:
            if interface.server in self.interfaces:
                self.interfaces.pop(interface.server)
            if interface.server == self.default_server:
                self.interface = None
            interface.close()

    @with_recent_servers_lock
    def add_recent_server(self, server):
        # list is ordered
        if server in self.recent_servers:
            self.recent_servers.remove(server)
        self.recent_servers.insert(0, server)
        self.recent_servers = self.recent_servers[0:20]
        self.save_recent_servers()

    @with_interface_lock
    def connection_down(self, server):
        '''A connection to server either went down, or was never made.
        We distinguish by whether it is in self.interfaces.'''
        self.disconnected_servers.add(server)
        if server == self.default_server:
            self.set_status('disconnected')
        if server in self.interfaces:
            self.close_interface(self.interfaces[server])
            self.notify('interfaces')

    @aiosafe
    async def new_interface(self, server):
        # todo: get tip first, then decide which checkpoint to use.
        self.add_recent_server(server)

        interface = Interface(self, server, self.config.path, self.proxy)
        timeout = 10 if not self.proxy else 20
        try:
            await asyncio.wait_for(interface.ready, timeout)
        except BaseException as e:
            self.print_error(interface.server, "couldn't launch because",
                             str(e), str(type(e)))
            self.connection_down(interface.server)
            return
        finally:
            try:
                self.connecting.remove(server)
            except KeyError:
                pass

        with self.interface_lock:
            self.interfaces[server] = interface

        if server == self.default_server:
            self.switch_to_interface(server)

        self.notify('interfaces')

    def init_headers_file(self):
        b = blockchain.blockchains[0]
        filename = b.path()
        length = 80 * len(constants.net.CHECKPOINTS) * 2016
        if not os.path.exists(filename) or os.path.getsize(filename) < length:
            with open(filename, 'wb') as f:
                if length > 0:
                    f.seek(length - 1)
                    f.write(b'\x00')
        with b.lock:
            b.update_size()

    async def get_merkle_for_transaction(self, tx_hash, tx_height):
        return await self.interface.session.send_request(
            'blockchain.transaction.get_merkle', [tx_hash, tx_height])

    def broadcast_transaction_from_non_network_thread(self, tx, timeout=10):
        # note: calling this from the network thread will deadlock it
        fut = asyncio.run_coroutine_threadsafe(
            self.broadcast_transaction(tx, timeout=timeout), self.asyncio_loop)
        return fut.result()

    async def broadcast_transaction(self, tx, timeout=10):
        try:
            out = await self.interface.session.send_request(
                'blockchain.transaction.broadcast', [str(tx)], timeout=timeout)
        except asyncio.TimeoutError:
            return False, "error: operation timed out"
        except Exception as e:
            return False, "error: " + str(e)

        if out != tx.txid():
            return False, "error: " + out
        return True, out

    async def request_chunk(self,
                            height,
                            tip,
                            session=None,
                            can_return_early=False):
        if session is None: session = self.interface.session
        index = height // 2016
        if can_return_early and index in self.requested_chunks:
            return
        size = 2016
        if tip is not None:
            size = min(size, tip - index * 2016)
            size = max(size, 0)
        try:
            self.requested_chunks.add(index)
            res = await session.send_request('blockchain.block.headers',
                                             [index * 2016, size])
        finally:
            try:
                self.requested_chunks.remove(index)
            except KeyError:
                pass
        conn = self.blockchain().connect_chunk(index, res['hex'])
        if not conn:
            return conn, 0
        return conn, res['count']

    @with_interface_lock
    def blockchain(self):
        if self.interface and self.interface.blockchain is not None:
            self.blockchain_index = self.interface.blockchain.forkpoint
        return blockchain.blockchains[self.blockchain_index]

    @with_interface_lock
    def get_blockchains(self):
        out = {}
        with self.blockchains_lock:
            blockchain_items = list(blockchain.blockchains.items())
        for k, b in blockchain_items:
            r = list(
                filter(lambda i: i.blockchain == b,
                       list(self.interfaces.values())))
            if r:
                out[k] = r
        return out

    def follow_chain(self, index):
        bc = blockchain.blockchains.get(index)
        if bc:
            self.blockchain_index = index
            self.config.set_key('blockchain_index', index)
            with self.interface_lock:
                interfaces = list(self.interfaces.values())
            for i in interfaces:
                if i.blockchain == bc:
                    self.switch_to_interface(i.server)
                    break
        else:
            raise Exception('blockchain not found', index)

        with self.interface_lock:
            if self.interface:
                net_params = self.get_parameters()
                host, port, protocol = deserialize_server(
                    self.interface.server)
                net_params = net_params._replace(host=host,
                                                 port=port,
                                                 protocol=protocol)
                self.set_parameters(net_params)

    def get_local_height(self):
        return self.blockchain().height()

    def export_checkpoints(self, path):
        # run manually from the console to generate checkpoints
        cp = self.blockchain().get_checkpoints()
        with open(path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(cp, indent=4))

    def start(self, fx=None):
        self.main_taskgroup = TaskGroup()

        async def main():
            self.init_headers_file()
            async with self.main_taskgroup as group:
                await group.spawn(self.maintain_sessions())
                if fx: await group.spawn(fx)

        self._wrapper_thread = threading.Thread(
            target=self.asyncio_loop.run_until_complete, args=(main(), ))
        self._wrapper_thread.start()

    def stop(self):
        asyncio.run_coroutine_threadsafe(
            self.main_taskgroup.cancel_remaining(), self.asyncio_loop)

    def join(self):
        self._wrapper_thread.join(1)

    async def maintain_sessions(self):
        while True:
            while self.socket_queue.qsize() > 0:
                server = self.socket_queue.get()
                asyncio.get_event_loop().create_task(
                    self.new_interface(server))
            remove = []
            for k, i in self.interfaces.items():
                if i.fut.done() and not i.exception:
                    assert False, "interface future should not finish without exception"
                if i.exception:
                    if not i.fut.done():
                        try:
                            i.fut.cancel()
                        except Exception as e:
                            self.print_error('exception while cancelling fut',
                                             e)
                    try:
                        raise i.exception
                    except BaseException as e:
                        self.print_error(i.server, "errored because:", str(e),
                                         str(type(e)))
                    remove.append(k)
            for k in remove:
                self.connection_down(k)

            # nodes
            now = time.time()
            for _ in range(self.num_server - len(self.interfaces) -
                           len(self.connecting)):
                self.start_random_interface()
            if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
                self.print_error('network: retrying connections')
                self.disconnected_servers = set([])
                self.nodes_retry_time = now

            # main interface
            if not self.is_connected():
                if self.auto_connect:
                    if not self.is_connecting():
                        self.switch_to_random_interface()
                else:
                    if self.default_server in self.disconnected_servers:
                        if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
                            self.disconnected_servers.remove(
                                self.default_server)
                            self.server_retry_time = now
                    else:
                        self.switch_to_interface(self.default_server)
            else:
                if self.config.is_fee_estimates_update_required():
                    await self.interface.group.spawn(
                        self.request_fee_estimates(self.interface))

            await asyncio.sleep(0.1)
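
A minimal sketch (not from the source) of the fan-out pattern used in request_server_info and request_fee_estimates above: spawn one task per request inside a TaskGroup, then read results once the group has joined. It assumes aiorpcx's TaskGroup; fake_request is a hypothetical stand-in for session.send_request.

import asyncio
from aiorpcx import TaskGroup  # assumption: the aiorpcx TaskGroup these projects use

async def fake_request(method, params=None):
    await asyncio.sleep(0.01)  # stand-in for network I/O
    return (method, params)

async def fan_out():
    async with TaskGroup() as group:
        histogram_task = await group.spawn(fake_request('mempool.get_fee_histogram'))
        fee_tasks = [(i, await group.spawn(fake_request('blockchain.estimatefee', [i])))
                     for i in (2, 5, 25)]
    # The group has joined: every task is done, so result() cannot block;
    # it returns the value or re-raises the task's exception.
    print(histogram_task.result())
    for i, task in fee_tasks:
        print(i, task.result())

asyncio.get_event_loop().run_until_complete(fan_out())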
Пример #25
0
    async def _process_mempool(
        self,
        *,
        all_hashes: Set[bytes],  # set of txids
        touched_hashxs: Set[bytes],  # set of hashXs
        touched_outpoints: Set[Tuple[bytes, int]],  # set of outpoints
        mempool_height: int,
    ) -> None:
        # Re-sync with the new set of hashes
        txs = self.txs
        hashXs = self.hashXs
        txo_to_spender = self.txo_to_spender

        if mempool_height != self.api.db_height():
            raise DBSyncError

        # First handle txs that have disappeared
        for tx_hash in (set(txs) - all_hashes):
            tx = txs.pop(tx_hash)
            # hashXs
            tx_hashXs = {hashX for hashX, value in tx.in_pairs}
            tx_hashXs.update(hashX for hashX, value in tx.out_pairs)
            for hashX in tx_hashXs:
                hashXs[hashX].remove(tx_hash)
                if not hashXs[hashX]:
                    del hashXs[hashX]
            touched_hashxs |= tx_hashXs
            # outpoints
            for prevout in tx.prevouts:
                del txo_to_spender[prevout]
                touched_outpoints.add(prevout)
            for out_idx in range(len(tx.out_pairs)):
                touched_outpoints.add((tx_hash, out_idx))

        # Process new transactions
        new_hashes = list(all_hashes.difference(txs))
        if new_hashes:
            group = TaskGroup()
            for hashes in chunks(new_hashes, 200):
                coro = self._fetch_and_accept(
                    hashes=hashes,
                    all_hashes=all_hashes,
                    touched_hashxs=touched_hashxs,
                    touched_outpoints=touched_outpoints,
                )
                await group.spawn(coro)
            if mempool_height != self.api.db_height():
                raise DBSyncError

            tx_map = {}
            utxo_map = {}
            async for task in group:
                deferred, unspent = task.result()
                tx_map.update(deferred)
                utxo_map.update(unspent)

            prior_count = 0
            # FIXME: this is not particularly efficient
            while tx_map and len(tx_map) != prior_count:
                prior_count = len(tx_map)
                tx_map, utxo_map = self._accept_transactions(
                    tx_map=tx_map,
                    utxo_map=utxo_map,
                    touched_hashxs=touched_hashxs,
                    touched_outpoints=touched_outpoints,
                )
            if tx_map:
                self.logger.error(f'{len(tx_map)} txs dropped')
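
Note how _process_mempool drives a bare TaskGroup (no async with) and drains it with async for, which yields tasks as they complete. The chunks helper is not shown in this excerpt; a plausible definition, plus a self-contained sketch of the same drain pattern assuming aiorpcx's TaskGroup:

import asyncio
from aiorpcx import TaskGroup  # assumption: aiorpcx-style TaskGroup

def chunks(items, size):
    '''Yield successive slices of `items` with at most `size` elements each.'''
    for i in range(0, len(items), size):
        yield items[i:i + size]

async def fetch(batch):  # hypothetical stand-in for _fetch_and_accept
    await asyncio.sleep(0.01)
    return {h: len(h) for h in batch}

async def drain():
    group = TaskGroup()
    for batch in chunks(['aa', 'bb', 'cc', 'dd', 'ee'], 2):
        await group.spawn(fetch(batch))
    merged = {}
    async for task in group:  # tasks are yielded as they complete
        merged.update(task.result())
    return merged

asyncio.get_event_loop().run_until_complete(drain())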
Пример #26
0
class Daemon(Logger):

    network: Optional[Network]
    gui_object: Optional['gui.BaseElectrumSysGui']

    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        Logger.__init__(self)
        self.config = config
        self.listen_jsonrpc = listen_jsonrpc
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        if 'wallet_path' in config.cmdline_options:
            self.logger.warning("Ignoring parameter 'wallet_path' for daemon. "
                                "Use the load_wallet command instead.")
        self.asyncio_loop = asyncio.get_event_loop()
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        daemon_jobs = []
        # Setup commands server
        self.commands_server = None
        if listen_jsonrpc:
            self.commands_server = CommandsServer(self, fd)
            daemon_jobs.append(self.commands_server.run())
        # pay server
        self.pay_server = None
        payserver_address = self.config.get_netaddress('payserver_address')
        if not config.get('offline') and payserver_address:
            self.pay_server = PayServer(self, payserver_address)
            daemon_jobs.append(self.pay_server.run())
        # server-side watchtower
        self.watchtower = None
        watchtower_address = self.config.get_netaddress('watchtower_address')
        if not config.get('offline') and watchtower_address:
            self.watchtower = WatchTowerServer(self.network, watchtower_address)
            daemon_jobs.append(self.watchtower.run)
        if self.network:
            self.network.start(jobs=[self.fx.run])
            # prepare lightning functionality, also load channel db early
            if self.config.get('use_gossip', False):
                self.network.start_gossip()

        self._stop_entered = False
        self._stopping_soon_or_errored = threading.Event()
        self._stopped_event = threading.Event()
        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs), self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        if jobs is None:
            jobs = []
        self.logger.info("starting taskgroup.")
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                await group.spawn(asyncio.Event().wait)  # run forever (until cancel)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.exception("taskgroup died.")
            util.send_exception_to_crash_reporter(e)
        finally:
            self.logger.info("taskgroup stopped.")
            # note: we could just "await self.stop()", but in that case GUI users would
            #       not see the exception (especially if the GUI did not start yet).
            self._stopping_soon_or_errored.set()

    def load_wallet(self, path, password, *, manual_upgrades=True) -> Optional[Abstract_Wallet]:
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        # note: this must not be called from the event loop. # TODO raise if so
        fut = asyncio.run_coroutine_threadsafe(self._stop_wallet(path), self.asyncio_loop)
        return fut.result()

    async def _stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if not wallet:
            return False
        await wallet.stop()
        return True

    def run_daemon(self):
        try:
            self._stopping_soon_or_errored.wait()
        except KeyboardInterrupt:
            asyncio.run_coroutine_threadsafe(self.stop(), self.asyncio_loop).result()
        self._stopped_event.wait()

    async def stop(self):
        if self._stop_entered:
            return
        self._stop_entered = True
        self._stopping_soon_or_errored.set()
        self.logger.info("stop() entered. initiating shutdown")
        try:
            if self.gui_object:
                self.gui_object.stop()
            self.logger.info("stopping all wallets")
            async with TaskGroup() as group:
                for wallet in self._wallets.values():
                    await group.spawn(wallet.stop())
            self.logger.info("stopping network and taskgroup")
            async with ignore_after(2):
                async with TaskGroup() as group:
                    if self.network:
                        await group.spawn(self.network.stop(full_shutdown=True))
                    await group.spawn(self.taskgroup.cancel_remaining())
        finally:
            if self.listen_jsonrpc:
                self.logger.info("removing lockfile")
                remove_lockfile(get_lockfile(self.config))
            self.logger.info("stopped")
            self._stopped_event.set()

    def run_gui(self, config, plugins):
        threading.current_thread().name = 'GUI'
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrumsys.gui.' + gui_name, fromlist=['electrumsys'])
            self.gui_object = gui.ElectrumSysGui(config=config, daemon=self, plugins=plugins)
            if not self._stop_entered:
                self.gui_object.main()
            else:
                # If daemon.stop() was called before gui_object got created, stop gui now.
                self.gui_object.stop()
        except BaseException as e:
            self.logger.error(f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            asyncio.run_coroutine_threadsafe(self.stop(), self.asyncio_loop).result()
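
The Daemon above funnels everything through a single asyncio loop: run_daemon(), stop_wallet() and run_gui() all hand coroutines to that loop with run_coroutine_threadsafe and block on the returned future. A standalone sketch of that thread-to-loop handoff (names are illustrative, not from the source):

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def stop_everything():
    await asyncio.sleep(0)  # stand-in for wallet/network shutdown work
    return 'stopped'

# Submit from a foreign thread; this blocks that thread, never the loop.
fut = asyncio.run_coroutine_threadsafe(stop_everything(), loop)
print(fut.result(timeout=5))
loop.call_soon_threadsafe(loop.stop)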
Пример #27
0
async def many_payments():
    # pay_reqs2, single_payment and gath come from the enclosing test scope
    async with TaskGroup() as group:
        for pay_req in pay_reqs2:
            await group.spawn(single_payment(pay_req))
    gath.cancel()
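
A self-contained approximation of the same shape, assuming aiorpcx's TaskGroup: the group joins once every spawned payment has finished, and only then is an outer task cancelled (playing the role of gath above).

import asyncio
from aiorpcx import TaskGroup  # assumption: aiorpcx-style TaskGroup

async def single_payment(req):
    await asyncio.sleep(0.01)  # stand-in for a real payment round-trip
    return req

async def demo():
    loop = asyncio.get_event_loop()
    outer = loop.create_task(asyncio.sleep(60))  # plays the role of `gath`
    async with TaskGroup() as group:  # exits only when all payments are done
        for req in range(3):
            await group.spawn(single_payment(req))
    outer.cancel()  # mirrors gath.cancel() above

asyncio.get_event_loop().run_until_complete(demo())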
Пример #28
0
async def test_notifications(caplog):
    # Tests notifications over a cycle of:
    # 1) A first batch of txs come in
    # 2) A second batch of txs come in
    # 3) A block comes in confirming the first batch only
    api = API()
    api.initialize()
    mempool = MemPool(coin, api, refresh_secs=0.001, log_status_secs=0)
    event = Event()

    n = len(api.ordered_adds) // 2
    raw_txs = api.raw_txs.copy()
    txs = api.txs.copy()
    first_hashes = api.ordered_adds[:n]
    first_touched = api.touched(first_hashes)
    second_hashes = api.ordered_adds[n:]
    second_touched = api.touched(second_hashes)

    caplog.set_level(logging.DEBUG)

    async with TaskGroup() as group:
        # First batch enters the mempool
        api.raw_txs = {hash: raw_txs[hash] for hash in first_hashes}
        api.txs = {hash: txs[hash] for hash in first_hashes}
        first_utxos = api.mempool_utxos()
        first_spends = api.mempool_spends()
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        assert len(api.on_mempool_calls) == 1
        touched, height = api.on_mempool_calls[0]
        assert height == api._height == api._db_height == api._cached_height
        assert touched == first_touched
        # Second batch enters the mempool
        api.raw_txs = raw_txs
        api.txs = txs
        await event.wait()
        assert len(api.on_mempool_calls) == 2
        touched, height = api.on_mempool_calls[1]
        assert height == api._height == api._db_height == api._cached_height
        # Touched is incremental
        assert touched == second_touched
        # Block found; first half confirm
        new_height = 2
        api._height = new_height
        api.raw_txs = {hash: raw_txs[hash] for hash in second_hashes}
        api.txs = {hash: txs[hash] for hash in second_hashes}
        # Delay the DB update
        assert not in_caplog(caplog, 'waiting for DB to sync')
        async with ignore_after(max(mempool.refresh_secs * 2, 0.5)):
            await event.wait()
        assert in_caplog(caplog, 'waiting for DB to sync')
        assert len(api.on_mempool_calls) == 2
        assert not event.is_set()
        assert api._height == api._cached_height == new_height
        assert touched == second_touched
        # Now update the DB
        api.db_utxos.update(first_utxos)
        api._db_height = new_height
        for spend in first_spends:
            del api.db_utxos[spend]
        await event.wait()
        assert len(api.on_mempool_calls) == 3
        touched, height = api.on_mempool_calls[2]
        assert height == api._db_height == new_height
        assert touched == first_touched
        await group.cancel_remaining()
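
The in_caplog helper the test relies on is not shown in this excerpt; it is presumably a small predicate over pytest's caplog records, along these lines:

def in_caplog(caplog, message):
    '''Return True if `message` occurs in any captured log record.'''
    return any(message in record.getMessage() for record in caplog.records)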