示例#1
0
class Daemon(Logger):
    """Electrum daemon: owns the network, the loaded wallets and the
    long-running server jobs (commands server, pay server, watchtower).

    The jobs run inside ``self.taskgroup`` on the asyncio loop captured at
    construction time.
    """

    network: Optional[Network]

    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        """Set up the subsystems and schedule the daemon taskgroup.

        :param config: the application config.
        :param fd: lockfile descriptor; when None and ``listen_jsonrpc`` is
            true it is obtained from ``get_file_descriptor(config)``.
        :param listen_jsonrpc: whether to start the commands server.
        :raises Exception: if no lockfile descriptor could be obtained
            (presumably another daemon is already running).
        """
        Logger.__init__(self)
        self.running = False
        self.running_lock = threading.Lock()
        self.config = config
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        self.asyncio_loop = asyncio.get_event_loop()
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        daemon_jobs = []
        # Setup commands server
        self.commands_server = None
        if listen_jsonrpc:
            self.commands_server = CommandsServer(self, fd)
            daemon_jobs.append(self.commands_server.run())
        # pay server
        self.pay_server = None
        payserver_address = self.config.get_netaddress('payserver_address')
        if not config.get('offline') and payserver_address:
            self.pay_server = PayServer(self, payserver_address)
            daemon_jobs.append(self.pay_server.run())
        # server-side watchtower
        self.watchtower = None
        watchtower_address = self.config.get_netaddress('watchtower_address')
        if not config.get('offline') and watchtower_address:
            self.watchtower = WatchTowerServer(self.network,
                                               watchtower_address)
            daemon_jobs.append(self.watchtower.run)
        if self.network:
            self.network.start(jobs=[self.fx.run])

        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs),
                                         self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        """Spawn ``jobs`` in ``self.taskgroup`` and keep it alive until cancelled.

        Cancellation propagates. Any other exception is logged; afterwards
        is_running() reports False because the taskgroup is closed.
        """
        if jobs is None:
            jobs = []
        self.logger.info("starting taskgroup.")
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                # run forever (until cancel)
                await group.spawn(asyncio.Event().wait)
        except asyncio.CancelledError:
            raise
        except Exception:
            self.logger.exception("taskgroup died.")
        finally:
            self.logger.info("taskgroup stopped.")

    def load_wallet(self,
                    path,
                    password,
                    *,
                    manual_upgrades=True) -> Optional[Abstract_Wallet]:
        """Return the wallet stored at ``path``, opening it if necessary.

        Returns None (so that the caller can launch the wizard) when the
        file does not exist, is encrypted and no password was given, or the
        DB requires a split, an upgrade or a pending action.
        """
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        """Register an already-open wallet under its standardized path."""
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
        """Return the loaded wallet for ``path``, or None."""
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        """Return a shallow copy of the path -> wallet mapping."""
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        """Stop the wallet (if loaded) and delete its file.

        Returns True iff a file existed at ``path`` and was removed.
        """
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if not wallet:
            return False
        wallet.stop()
        return True

    def run_daemon(self):
        """Block until stop(), Ctrl-C or a closed taskgroup, then on_stop()."""
        # running is read under running_lock in is_running(); write it the same way
        with self.running_lock:
            self.running = True
        try:
            while self.is_running():
                time.sleep(0.1)
        except KeyboardInterrupt:
            with self.running_lock:
                self.running = False
        self.on_stop()

    def is_running(self):
        """True while running is set and the taskgroup has not closed."""
        with self.running_lock:
            return self.running and not self.taskgroup.closed()

    def stop(self):
        """Ask run_daemon() to exit its polling loop."""
        with self.running_lock:
            self.running = False

    def on_stop(self):
        """Stop GUI, wallets and network, cancel the taskgroup (waiting up
        to 2 seconds) and remove the lockfile."""
        if self.gui_object:
            self.gui_object.stop()
        # stop network/wallets; iterate a snapshot in case the dict changes
        for wallet in list(self._wallets.values()):
            wallet.stop()
        if self.network:
            self.logger.info("shutting down network")
            self.network.stop()
        self.logger.info("stopping taskgroup")
        fut = asyncio.run_coroutine_threadsafe(
            self.taskgroup.cancel_remaining(), self.asyncio_loop)
        try:
            fut.result(timeout=2)
        except (concurrent.futures.TimeoutError,
                concurrent.futures.CancelledError, asyncio.CancelledError):
            pass
        self.logger.info("removing lockfile")
        remove_lockfile(get_lockfile(self.config))
        self.logger.info("stopped")

    def run_gui(self, config, plugins):
        """Import and run the configured GUI; always calls on_stop() when it exits."""
        threading.current_thread().name = 'GUI'
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrum.gui.' + gui_name, fromlist=['electrum'])
            self.gui_object = gui.ElectrumGui(config, self, plugins)
            self.gui_object.main()
        except BaseException as e:
            self.logger.error(
                f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            self.on_stop()
示例#2
0
class Daemon(Logger):
    """Electrum daemon (aiohttp JSON-RPC variant).

    Owns the network, the loaded wallets and the long-running server jobs
    (JSON-RPC server, pay server, watchtower). The jobs run inside
    ``self.taskgroup`` on the asyncio loop captured at construction time.
    """

    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        """Set up the subsystems and schedule the daemon taskgroup.

        :param config: the application config.
        :param fd: lockfile descriptor; when None and ``listen_jsonrpc`` is
            true it is obtained from ``get_file_descriptor(config)``.
        :param listen_jsonrpc: whether to start the JSON-RPC server.
        :raises Exception: if no lockfile descriptor could be obtained
            (presumably another daemon is already running).
        """
        Logger.__init__(self)
        self.auth_lock = asyncio.Lock()
        self.running = False
        self.running_lock = threading.Lock()
        self.config = config
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        self.asyncio_loop = asyncio.get_event_loop()
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        daemon_jobs = []
        # Setup JSONRPC server
        if listen_jsonrpc:
            daemon_jobs.append(self.start_jsonrpc(config, fd))
        # request server
        self.pay_server = None
        if not config.get('offline') and self.config.get('run_payserver'):
            self.pay_server = PayServer(self)
            daemon_jobs.append(self.pay_server.run())
        # server-side watchtower
        self.watchtower = None
        if not config.get('offline') and self.config.get('run_watchtower'):
            self.watchtower = WatchTowerServer(self.network)
            daemon_jobs.append(self.watchtower.run)
        if self.network:
            self.network.start(jobs=[self.fx.run])

        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs),
                                         self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        """Spawn ``jobs`` in ``self.taskgroup`` and keep it alive until cancelled.

        Cancellation propagates. Previously it was swallowed by the
        BaseException handler and logged as a crash. Any other exception is
        logged.
        """
        if jobs is None:
            jobs = []
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                # run forever (until cancel)
                await group.spawn(asyncio.Event().wait)
        except asyncio.CancelledError:
            raise
        except Exception:
            self.logger.exception('daemon.taskgroup died.')
        finally:
            self.logger.info("stopping daemon.taskgroup")

    async def authenticate(self, headers):
        """Validate HTTP Basic credentials against rpc_user / rpc_password.

        :raises AuthenticationInvalidOrMissing: header absent, not "Basic",
            or not decodable.
        :raises AuthenticationCredentialsInvalid: wrong user or password
            (after a short delay).
        """
        if self.rpc_password == '':
            # RPC authentication is disabled
            return
        auth_string = headers.get('Authorization', None)
        if auth_string is None:
            raise AuthenticationInvalidOrMissing('CredentialsMissing')
        basic, _, encoded = auth_string.partition(' ')
        if basic != 'Basic':
            raise AuthenticationInvalidOrMissing('UnsupportedType')
        encoded = to_bytes(encoded, 'utf8')
        try:
            credentials = to_string(b64decode(encoded), 'utf8')
        except ValueError:
            # b64decode raises binascii.Error (a ValueError) on bad padding.
            # Decoding non-UTF-8 bytes presumably raises UnicodeDecodeError
            # (also a ValueError). Answer 401 instead of an internal error.
            raise AuthenticationInvalidOrMissing('MalformedCredentials') from None
        username, _, password = credentials.partition(':')
        if not (constant_time_compare(username, self.rpc_user)
                and constant_time_compare(password, self.rpc_password)):
            await asyncio.sleep(0.050)
            raise AuthenticationCredentialsInvalid('Invalid Credentials')

    async def handle(self, request):
        """aiohttp POST handler: authenticate, then dispatch the JSON-RPC body."""
        # authentication attempts are serialized through auth_lock
        async with self.auth_lock:
            try:
                await self.authenticate(request.headers)
            except AuthenticationInvalidOrMissing:
                return web.Response(
                    headers={"WWW-Authenticate": "Basic realm=Electrum"},
                    text='Unauthorized',
                    status=401)
            except AuthenticationCredentialsInvalid:
                return web.Response(text='Forbidden', status=403)
        request = await request.text()
        response = await jsonrpcserver.async_dispatch(request,
                                                      methods=self.methods)
        if isinstance(response, jsonrpcserver.response.ExceptionResponse):
            self.logger.error(f"error handling request: {request}",
                              exc_info=response.exc)
            # this exposes the error message to the client
            response.message = str(response.exc)
        if response.wanted:
            return web.json_response(response.deserialized(),
                                     status=response.http_status)
        else:
            return web.Response()

    async def start_jsonrpc(self, config: SimpleConfig, fd):
        """Start the JSON-RPC HTTP server.

        Registers the RPC methods and binds to rpchost/rpcport (defaults
        127.0.0.1 and 0). Then writes ``repr((sockname, time.time()))`` to
        ``fd`` and closes it.
        """
        self.app = web.Application()
        self.app.router.add_post("/", self.handle)
        self.rpc_user, self.rpc_password = get_rpc_credentials(config)
        self.methods = jsonrpcserver.methods.Methods()
        self.methods.add(self.ping)
        self.methods.add(self.gui)
        self.cmd_runner = Commands(config=self.config,
                                   network=self.network,
                                   daemon=self)
        for cmdname in known_commands:
            self.methods.add(getattr(self.cmd_runner, cmdname))
        self.methods.add(self.run_cmdline)
        self.host = config.get('rpchost', '127.0.0.1')
        self.port = config.get('rpcport', 0)
        self.runner = web.AppRunner(self.app)
        await self.runner.setup()
        site = web.TCPSite(self.runner, self.host, self.port)
        await site.start()
        socket = site._server.sockets[0]
        os.write(fd, bytes(repr((socket.getsockname(), time.time())), 'utf8'))
        os.close(fd)

    async def ping(self):
        """RPC liveness check."""
        return True

    async def gui(self, config_options):
        """RPC: ask the running GUI to open a new window; returns a status string."""
        if self.gui_object:
            if hasattr(self.gui_object, 'new_window'):
                path = self.config.get_wallet_path(use_gui_last_wallet=True)
                self.gui_object.new_window(path, config_options.get('url'))
                response = "ok"
            else:
                response = "error: current GUI does not support multiple windows"
        else:
            response = "Error: Electrum is running in daemon mode. Please stop the daemon first."
        return response

    def load_wallet(self,
                    path,
                    password,
                    *,
                    manual_upgrades=True) -> Optional[Abstract_Wallet]:
        """Return the wallet stored at ``path``, opening it if necessary.

        Returns None (so that the caller can launch the wizard) when the
        file does not exist, is encrypted and no password was given, or the
        DB requires a split, an upgrade or a pending action.
        """
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        self.wallet = wallet
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        """Register an already-open wallet under its standardized path."""
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
        """Return the loaded wallet for ``path``, or None."""
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        """Return a shallow copy of the path -> wallet mapping."""
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        """Stop the wallet (if loaded) and delete its file.

        Returns True iff a file existed at ``path`` and was removed.
        """
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if not wallet:
            return False
        wallet.stop_threads()
        return True

    async def run_cmdline(self, config_options):
        """RPC entry used by the command-line client.

        Runs ``config_options['cmd']`` through the Commands runner. Errors
        are returned as ``{'error': str(e)}`` rather than raised.
        """
        cmdname = config_options['cmd']
        cmd = known_commands[cmdname]
        # arguments passed to function
        args = [config_options.get(x) for x in cmd.params]
        # decode json arguments
        args = [json_decode(i) for i in args]
        # options
        kwargs = {}
        for x in cmd.options:
            kwargs[x] = config_options.get(x)
        if cmd.requires_wallet:
            kwargs['wallet_path'] = config_options.get('wallet_path')
        func = getattr(self.cmd_runner, cmd.name)
        # fixme: not sure how to retrieve message in jsonrpcclient
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            result = {'error': str(e)}
        return result

    def run_daemon(self):
        """Block until stop(), Ctrl-C or a closed taskgroup, then on_stop()."""
        # running is read under running_lock in is_running(); write it the same way
        with self.running_lock:
            self.running = True
        try:
            while self.is_running():
                time.sleep(0.1)
        except KeyboardInterrupt:
            with self.running_lock:
                self.running = False
        self.on_stop()

    def is_running(self):
        """True while running is set and the taskgroup has not closed."""
        with self.running_lock:
            return self.running and not self.taskgroup.closed()

    def stop(self):
        """Ask run_daemon() to exit its polling loop."""
        with self.running_lock:
            self.running = False

    def on_stop(self):
        """Stop GUI, wallets and network, cancel the taskgroup (waiting up
        to 2 seconds) and remove the lockfile."""
        import concurrent.futures
        if self.gui_object:
            self.gui_object.stop()
        # stop network/wallets; iterate a snapshot in case the dict changes
        for wallet in list(self._wallets.values()):
            wallet.stop_threads()
        if self.network:
            self.logger.info("shutting down network")
            self.network.stop()
        self.logger.info("stopping taskgroup")
        fut = asyncio.run_coroutine_threadsafe(
            self.taskgroup.cancel_remaining(), self.asyncio_loop)
        try:
            fut.result(timeout=2)
        except (concurrent.futures.TimeoutError,
                concurrent.futures.CancelledError,
                asyncio.TimeoutError, asyncio.CancelledError):
            # fut is a concurrent.futures.Future: before Python 3.11 its
            # TimeoutError/CancelledError are distinct from asyncio's, and
            # letting them escape would skip the lockfile removal below.
            pass
        self.logger.info("removing lockfile")
        remove_lockfile(get_lockfile(self.config))
        self.logger.info("stopped")

    def run_gui(self, config, plugins):
        """Import and run the configured GUI; always calls on_stop() when it exits."""
        threading.current_thread().name = 'GUI'
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrum.gui.' + gui_name, fromlist=['electrum'])
            self.gui_object = gui.ElectrumGui(config, self, plugins)
            self.gui_object.main()
        except BaseException as e:
            self.logger.error(
                f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            self.on_stop()
示例#3
0
class Network(PrintError):
    """The Network class manages a set of connections to remote electrum
    servers, each connected socket is handled by an Interface() object.
    Connections are initiated by a Connection() thread which stops once
    the connection succeeds or fails.

    Our external API:

    - Member functions get_header(), get_interfaces(), get_local_height(),
          get_parameters(), get_server_height(), get_status_value(),
          is_connected(), set_parameters(), stop()
    """
    verbosity_filter = 'n'

    def __init__(self, config=None):
        """Create the network, register it as the module singleton and
        start connecting to the default server.

        :param config: a SimpleConfig, a plain dict (wrapped into a
            SimpleConfig) or None (treated as an empty dict).
        """
        global INSTANCE
        INSTANCE = self
        # Capture the loop first: trigger_callback() schedules coroutine
        # callbacks on self.asyncio_loop, and start_network() below already
        # triggers callbacks (set_proxy / set_status).
        self.asyncio_loop = asyncio.get_event_loop()
        if config is None:
            config = {}  # Do not use mutables as default values!
        self.config = SimpleConfig(config) if isinstance(config,
                                                         dict) else config
        self.num_server = 10 if not self.config.get('oneserver') else 0
        blockchain.blockchains = blockchain.read_blockchains(
            self.config)  # note: needs self.blockchains_lock
        self.print_error("blockchains", list(blockchain.blockchains.keys()))
        # read through self.config like every other option here
        self.blockchain_index = self.config.get('blockchain_index', 0)
        if self.blockchain_index not in blockchain.blockchains.keys():
            self.blockchain_index = 0
        # Server for addresses and transactions
        self.default_server = self.config.get('server', None)
        # Sanitize default server
        if self.default_server:
            try:
                deserialize_server(self.default_server)
            except Exception:
                self.print_error(
                    'Warning: failed to parse server-string; falling back to random.'
                )
                self.default_server = None
        if not self.default_server:
            self.default_server = pick_random_server()

        # locks: if you need to take multiple ones, acquire them in the order they are defined here!
        self.bhi_lock = asyncio.Lock()
        self.interface_lock = threading.RLock()  # <- re-entrant
        self.callback_lock = threading.Lock()
        self.recent_servers_lock = threading.RLock()  # <- re-entrant
        self.blockchains_lock = threading.Lock()

        self.irc_servers = {}  # returned by interface (list from irc)
        self.recent_servers = self.read_recent_servers(
        )  # note: needs self.recent_servers_lock

        self.banner = ''
        self.donation_address = ''
        self.relay_fee = None
        # callbacks set by the GUI
        self.callbacks = defaultdict(list)  # note: needs self.callback_lock

        dir_path = os.path.join(self.config.path, 'certs')
        util.make_dir(dir_path)

        # retry times
        self.server_retry_time = time.time()
        self.nodes_retry_time = time.time()
        # kick off the network.  interface is the main server we are currently
        # communicating with.  interfaces is the set of servers we are connecting
        # to or have an ongoing connection with
        self.interface = None  # note: needs self.interface_lock
        self.interfaces = {}  # note: needs self.interface_lock
        self.auto_connect = self.config.get('auto_connect', True)
        self.connecting = set()
        self.requested_chunks = set()
        self.socket_queue = queue.Queue()
        self.start_network(
            deserialize_server(self.default_server)[2],
            deserialize_proxy(self.config.get('proxy')))

    @staticmethod
    def get_instance():
        """Return the module-level Network singleton (last constructed)."""
        current = INSTANCE
        return current

    def with_interface_lock(func):
        """Decorator: run the wrapped method while holding self.interface_lock.

        functools.wraps keeps the wrapped method's name/docstring, which the
        plain wrapper previously hid (every decorated method was reported
        as ``func_wrapper``).
        """
        import functools

        @functools.wraps(func)
        def func_wrapper(self, *args, **kwargs):
            with self.interface_lock:
                return func(self, *args, **kwargs)

        return func_wrapper

    def with_recent_servers_lock(func):
        """Decorator: run the wrapped method while holding self.recent_servers_lock.

        functools.wraps preserves the wrapped method's name and docstring.
        """
        import functools

        @functools.wraps(func)
        def func_wrapper(self, *args, **kwargs):
            with self.recent_servers_lock:
                return func(self, *args, **kwargs)

        return func_wrapper

    def register_callback(self, callback, events):
        """Subscribe ``callback`` to each event name in ``events``."""
        with self.callback_lock:
            for ev in events:
                subscribers = self.callbacks[ev]
                subscribers.append(callback)

    def unregister_callback(self, callback):
        """Remove one registration of ``callback`` from every event list."""
        with self.callback_lock:
            for subscribers in self.callbacks.values():
                if callback not in subscribers:
                    continue
                subscribers.remove(callback)

    def trigger_callback(self, event, *args):
        """Invoke every callback registered for ``event``.

        Plain callables run synchronously in the calling thread; coroutine
        functions are scheduled on self.asyncio_loop without waiting.
        """
        with self.callback_lock:
            snapshot = list(self.callbacks[event])
        for cb in snapshot:
            if not asyncio.iscoroutinefunction(cb):
                cb(event, *args)
                continue
            # FIXME: if callback throws, we will lose the traceback
            asyncio.run_coroutine_threadsafe(cb(event, *args),
                                             self.asyncio_loop)

    def read_recent_servers(self):
        """Load the recent-servers list from ``<config.path>/recent_servers``.

        Returns [] when there is no config path, the file is missing or
        unreadable, it is not valid JSON, or it does not contain a list.
        """
        if not self.config.path:
            return []
        path = os.path.join(self.config.path, "recent_servers")
        try:
            with open(path, "r", encoding='utf-8') as f:
                data = json.loads(f.read())
        except (OSError, ValueError):
            # OSError: missing/unreadable file; ValueError covers
            # json.JSONDecodeError and UnicodeDecodeError
            return []
        # callers iterate it as a list of server strings
        return data if isinstance(data, list) else []

    @with_recent_servers_lock
    def save_recent_servers(self):
        """Persist self.recent_servers as JSON to ``<config.path>/recent_servers``.

        Best effort: a write failure is logged, not raised.
        """
        if not self.config.path:
            return
        path = os.path.join(self.config.path, "recent_servers")
        s = json.dumps(self.recent_servers, indent=4, sort_keys=True)
        try:
            with open(path, "w", encoding='utf-8') as f:
                f.write(s)
        except OSError as e:
            self.print_error('failed to save recent servers:', repr(e))

    @with_interface_lock
    def get_server_height(self):
        """Tip height reported by the main interface, or 0 without one."""
        iface = self.interface
        if not iface:
            return 0
        return iface.tip

    def server_is_lagging(self):
        """True if the main server has no height or is more than one block
        behind our local chain."""
        server_height = self.get_server_height()
        if not server_height:
            self.print_error('no height for main interface')
            return True
        local_height = self.get_local_height()
        lagging = local_height - server_height > 1
        if lagging:
            self.print_error('%s is lagging (%d vs %d)' %
                             (self.default_server, server_height, local_height))
        return lagging

    def set_status(self, status):
        """Record the connection status and fire the 'status' callbacks."""
        self.connection_status = status
        self.notify(key='status')

    def is_connected(self):
        """True when a main interface exists and its ``ready`` future is done."""
        if self.interface is None:
            return False
        return self.interface.ready.done()

    def is_connecting(self):
        """True while the last status set via set_status() is 'connecting'."""
        status = self.connection_status
        return status == 'connecting'

    async def request_server_info(self, interface):
        """Fetch banner, donation address, peers, fee estimates and relay fee
        from ``interface`` once it is ready, notifying listeners as we go."""
        await interface.ready
        sess = interface.session
        banner = await sess.send_request('server.banner')
        self.banner = banner
        self.notify('banner')
        donation = await sess.send_request('server.donation_address')
        self.donation_address = donation
        peers = await sess.send_request('server.peers.subscribe')
        self.irc_servers = parse_servers(peers)
        self.notify('servers')
        await self.request_fee_estimates(interface)
        relayfee = await sess.send_request('blockchain.relayfee')
        if relayfee is None:
            self.relay_fee = None
        else:
            self.relay_fee = int(relayfee * COIN)

    async def request_fee_estimates(self, interface):
        """Request the mempool fee histogram and per-target fee estimates
        concurrently, store them in the config and notify listeners."""
        session = interface.session
        from .simple_config import FEE_ETA_TARGETS
        self.config.requested_fee_estimates()
        async with TaskGroup() as group:
            histogram_task = await group.spawn(
                session.send_request('mempool.get_fee_histogram'))
            fee_tasks = []
            for i in FEE_ETA_TARGETS:
                fee_tasks.append((i, await group.spawn(
                    session.send_request('blockchain.estimatefee', [i]))))
        self.config.mempool_fees = histogram_task.result()
        self.notify('fee_histogram')
        for i, task in fee_tasks:
            fee = int(task.result() * COIN)
            self.print_error("fee_estimates[%d]" % i, fee)
            if fee < 0:
                # Per the Electrum protocol docs, estimatefee returns -1 when
                # the server cannot estimate; don't store it as a fee rate.
                continue
            self.config.update_fee_estimates(i, fee)
        self.notify('fee')

    def get_status_value(self, key):
        """Return the current value associated with status ``key``.

        :raises ValueError: for an unknown key. Previously the code fell
            through and crashed with UnboundLocalError.
        """
        # lazily evaluated so only the requested value is computed
        getters = {
            'status': lambda: self.connection_status,
            'banner': lambda: self.banner,
            'fee': lambda: self.config.fee_estimates,
            'fee_histogram': lambda: self.config.mempool_fees,
            'updated': lambda: (self.get_local_height(),
                                self.get_server_height()),
            'servers': self.get_servers,
            'interfaces': self.get_interfaces,
        }
        try:
            getter = getters[key]
        except KeyError:
            raise ValueError('unknown status key: %r' % (key,)) from None
        return getter()

    def notify(self, key):
        """Fire callbacks for ``key``; 'status' and 'updated' carry no payload."""
        if key not in ('status', 'updated'):
            self.trigger_callback(key, self.get_status_value(key))
            return
        self.trigger_callback(key)

    def get_parameters(self) -> NetworkParameters:
        """Current server, proxy and auto_connect settings as NetworkParameters."""
        parts = deserialize_server(self.default_server)
        host, port, protocol = parts
        return NetworkParameters(host, port, protocol,
                                 self.proxy, self.auto_connect)

    def get_donation_address(self):
        """Server donation address, or None when not connected."""
        if not self.is_connected():
            return None
        return self.donation_address

    @with_interface_lock
    def get_interfaces(self):
        '''The interfaces that are in connected state'''
        return [server for server in self.interfaces]

    @with_recent_servers_lock
    def get_servers(self):
        """Return the known servers as ``{host: {protocol: port, ...}}``.

        Starts from the chain's DEFAULT_SERVERS and adds peer-announced
        servers if any, otherwise recent servers. Onion hosts are dropped
        when 'noonion' is set.
        """
        # Copy: updating the shared DEFAULT_SERVERS constant in place would
        # leak peer/recent servers into every later call.
        out = dict(constants.net.DEFAULT_SERVERS)
        if self.irc_servers:
            out.update(filter_version(self.irc_servers.copy()))
        else:
            for s in self.recent_servers:
                try:
                    host, port, protocol = deserialize_server(s)
                except Exception:
                    continue  # skip malformed entries
                if host not in out:
                    out[host] = {protocol: port}
        if self.config.get('noonion'):
            out = filter_noonion(out)
        return out

    @with_interface_lock
    def start_interface(self, server):
        """Queue a connection to ``server`` unless it is connected or pending."""
        if server in self.interfaces or server in self.connecting:
            return
        if server == self.default_server:
            self.print_error("connecting to %s as new interface" % server)
            self.set_status('connecting')
        self.connecting.add(server)
        self.socket_queue.put(server)

    def start_random_interface(self):
        """Start connecting to a random server that is neither connected nor
        marked disconnected; return it (falsy if none was picked)."""
        with self.interface_lock:
            excluded = self.disconnected_servers | set(self.interfaces)
        candidate = pick_random_server(self.get_servers(), self.protocol,
                                       excluded)
        if candidate:
            self.start_interface(candidate)
        return candidate

    def set_proxy(self, proxy: Optional[dict]):
        """Store ``proxy`` and monkey-patch ``socket.getaddrinfo`` accordingly.

        With a proxy, getaddrinfo is replaced by a stub that performs no
        lookup and returns (args[0], args[1]) as an AF_INET/SOCK_STREAM
        address. Without a proxy, the original function is restored (on
        Windows, _fast_getaddrinfo is installed instead). Finally fires the
        'proxy_set' callbacks with the new proxy.
        """
        self.proxy = proxy
        # Store these somewhere so we can un-monkey-patch
        # (only the first call saves it, so later calls never capture a patched version)
        if not hasattr(socket, "_getaddrinfo"):
            socket._getaddrinfo = socket.getaddrinfo
        if proxy:
            self.print_error('setting proxy', proxy)
            # prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy
            # 6 is the protocol number used in the returned tuple (IPPROTO_TCP)
            socket.getaddrinfo = lambda *args: [(
                socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))]
        else:
            if sys.platform == 'win32':
                # On Windows, socket.getaddrinfo takes a mutex, and might hold it for up to 10 seconds
                # when dns-resolving. To speed it up drastically, we resolve dns ourselves, outside that lock.
                # see #4421
                socket.getaddrinfo = self._fast_getaddrinfo
            else:
                socket.getaddrinfo = socket._getaddrinfo
        self.trigger_callback('proxy_set', self.proxy)

    @staticmethod
    def _fast_getaddrinfo(host, *args, **kwargs):
        """getaddrinfo replacement that resolves ``host`` via dnspython first.

        IP literals and localhost are passed through unchanged. A DNS
        failure is reported as socket.gaierror. Any other resolver error
        falls back to the original socket.getaddrinfo with the unresolved
        host.
        """
        def needs_dns_resolving(host2):
            try:
                ipaddress.ip_address(host2)
                return False  # already valid IP
            except ValueError:
                pass  # not an IP
            # use the argument, not the enclosing ``host``
            if str(host2) in (
                    'localhost',
                    'localhost.',
            ):
                return False
            return True

        try:
            if needs_dns_resolving(host):
                answers = dns.resolver.query(host)
                addr = str(answers[0])
            else:
                addr = host
        except dns.exception.DNSException:
            # dns failed for some reason, e.g. dns.resolver.NXDOMAIN
            # this is normal. Simply report back failure:
            raise socket.gaierror(11001, 'getaddrinfo failed')
        except Exception as e:
            # Possibly internal error in dnspython :( see #4483
            # Fall back to original socket.getaddrinfo to resolve dns.
            # (Exception, not BaseException: don't swallow KeyboardInterrupt/SystemExit.)
            print_error('dnspython failed to resolve dns with error:', e)
            addr = host
        return socket._getaddrinfo(addr, *args, **kwargs)

    @with_interface_lock
    def start_network(self, protocol: str, proxy: Optional[dict]):
        """Start from a fully stopped state: set protocol and proxy, then
        connect to the default server."""
        # preconditions: nothing connected, connecting, or queued
        assert not self.interface and not self.interfaces
        assert not self.connecting and self.socket_queue.empty()
        self.print_error('starting network')
        self.disconnected_servers = set()  # note: needs self.interface_lock
        self.protocol = protocol
        self.set_proxy(proxy)
        self.start_interface(self.default_server)

    @with_interface_lock
    def stop_network(self):
        """Close every interface, forget pending connections and replace the socket queue."""
        self.print_error("stopping network")
        open_ifaces = list(self.interfaces.values())
        for iface in open_ifaces:
            self.close_interface(iface)
        main_iface = self.interface
        if main_iface:
            self.close_interface(main_iface)
        assert self.interface is None
        assert not self.interfaces
        self.connecting.clear()
        # Get a new queue - no old pending connections thanks!
        self.socket_queue = queue.Queue()

    def set_parameters(self, net_params: NetworkParameters):
        """Apply new server/proxy/auto_connect settings.

        Invalid parameters, or parameters the config refuses to store, are
        silently ignored. A proxy or protocol change restarts the network;
        a server change switches interface; otherwise a lagging interface
        may be switched and 'updated' is fired.
        """
        proxy = net_params.proxy
        proxy_str = serialize_proxy(proxy)
        host, port, protocol = net_params.host, net_params.port, net_params.protocol
        server_str = serialize_server(host, port, protocol)
        # sanitize parameters
        try:
            deserialize_server(server_str)
            if proxy:
                proxy_modes.index(proxy["mode"])  # raises if unknown mode
                int(proxy['port'])
        except Exception:
            return
        self.config.set_key('auto_connect', net_params.auto_connect, False)
        self.config.set_key("proxy", proxy_str, False)
        self.config.set_key("server", server_str, True)
        # abort if changes were not allowed by config
        if self.config.get('server') != server_str or self.config.get(
                'proxy') != proxy_str:
            return
        self.auto_connect = net_params.auto_connect
        if self.proxy != proxy or self.protocol != protocol:
            # Restart the network defaulting to the given server
            with self.interface_lock:
                self.stop_network()
                self.default_server = server_str
                self.start_network(protocol, proxy)
        elif self.default_server != server_str:
            self.switch_to_interface(server_str)
        else:
            self.switch_lagging_interface()
            self.notify('updated')

    def switch_to_random_interface(self):
        '''Switch to a random connected server other than the current one'''
        candidates = self.get_interfaces()  # servers in connected state
        if self.default_server in candidates:
            candidates.remove(self.default_server)
        if not candidates:
            return
        picked = random.choice(candidates)
        self.switch_to_interface(picked)

    @with_interface_lock
    def switch_lagging_interface(self):
        '''If auto_connect and lagging, switch interface'''
        if not (self.server_is_lagging() and self.auto_connect):
            return
        # Candidates are interfaces whose tip header equals our header at our
        # local height, i.e. the same header and not merely the same height.
        local_header = self.blockchain().read_header(self.get_local_height())

        def has_same_tip(iface):
            tip = iface.tip_header
            assert type(tip) is type(local_header)
            return tip == local_header

        matching = [server for server, iface in self.interfaces.items()
                    if has_same_tip(iface)]
        if matching:
            self.switch_to_interface(random.choice(matching))

    @with_interface_lock
    def switch_to_interface(self, server):
        '''Switch to server as our interface.  If no connection exists nor
        being opened, start a thread to connect.  The actual switch will
        happen on receipt of the connection notification.  Do nothing
        if server already is our interface.'''
        self.default_server = server
        if server not in self.interfaces:
            self.interface = None
            self.start_interface(server)
            return

        i = self.interfaces[server]
        if self.interface != i:
            self.print_error("switching to", server)
            if self.interface is not None:
                # Stop any current interface in order to terminate subscriptions,
                # and to cancel tasks in interface.group.
                # However, for headers sub, give preference to this interface
                # over unknown ones, i.e. start it again right away.
                old_server = self.interface.server
                self.close_interface(self.interface)
                if len(self.interfaces) <= self.num_server:
                    self.start_interface(old_server)

            self.interface = i
            # This method is also reached from outside the network thread
            # (e.g. set_parameters / follow_chain). loop.create_task() is not
            # thread-safe, so hand the coroutine to the network loop via
            # run_coroutine_threadsafe, which works from any thread.
            asyncio.run_coroutine_threadsafe(
                i.group.spawn(self.request_server_info(i)), self.asyncio_loop)
            self.trigger_callback('default_server_changed')
            self.set_status('connected')
            self.notify('updated')
            self.notify('interfaces')

    @with_interface_lock
    def close_interface(self, interface):
        """Remove `interface` from the connected set and close it.

        Does nothing for a falsy `interface`. If the interface belongs to the
        default server, the active interface (self.interface) is cleared too.
        """
        if not interface:
            return
        server = interface.server
        self.interfaces.pop(server, None)
        if server == self.default_server:
            self.interface = None
        interface.close()

    @with_recent_servers_lock
    def add_recent_server(self, server):
        """Move `server` to the front of the recent-servers list and save it.

        The list is ordered most-recent-first and is cut to 20 entries.
        """
        recent = self.recent_servers
        if server in recent:
            recent.remove(server)
        recent.insert(0, server)
        self.recent_servers = recent[:20]
        self.save_recent_servers()

    @with_interface_lock
    def connection_down(self, server):
        '''A connection to server either went down, or was never made.
        We distinguish by whether it is in self.interfaces.'''
        self.disconnected_servers.add(server)
        if server == self.default_server:
            self.set_status('disconnected')
        if server not in self.interfaces:
            # never connected: nothing to close
            return
        self.close_interface(self.interfaces[server])
        self.notify('interfaces')

    @aiosafe
    async def new_interface(self, server):
        """Open a connection to `server` and register it once it is ready.

        Waits for interface.ready for up to 10 seconds, or 20 when a proxy is
        configured.
        - On failure: the error is logged and the server is reported through
          connection_down().
        - On success: the interface is stored in self.interfaces and, if it
          is the default server, made the active interface.
        In both cases `server` is removed from self.connecting.
        """
        # todo: get tip first, then decide which checkpoint to use.
        self.add_recent_server(server)

        iface = Interface(self, server, self.config.path, self.proxy)
        wait_secs = 20 if self.proxy else 10
        try:
            await asyncio.wait_for(iface.ready, wait_secs)
        except BaseException as exc:
            # Deliberately broad. NOTE(review): presumably this also covers
            # interface.ready being cancelled by the Interface itself — confirm
            # before narrowing.
            self.print_error(iface.server, "couldn't launch because",
                             str(exc), str(type(exc)))
            self.connection_down(iface.server)
            return
        finally:
            try:
                self.connecting.remove(server)
            except KeyError:
                pass

        with self.interface_lock:
            self.interfaces[server] = iface

        if server == self.default_server:
            self.switch_to_interface(server)

        self.notify('interfaces')

    def init_headers_file(self):
        """Make sure the base chain's headers file covers all checkpoints.

        The required size is 80 * 2016 * len(constants.net.CHECKPOINTS) bytes.
        If the file is missing or smaller than that, it is re-created, with
        any previous contents discarded ('wb' mode). The new file ends in a
        single zero byte at the last position.
        The blockchain's size is then refreshed while holding its lock.
        """
        base_chain = blockchain.blockchains[0]
        headers_path = base_chain.path()
        required_size = 80 * len(constants.net.CHECKPOINTS) * 2016
        needs_init = (not os.path.exists(headers_path)
                      or os.path.getsize(headers_path) < required_size)
        if needs_init:
            with open(headers_path, 'wb') as fh:
                if required_size > 0:
                    fh.seek(required_size - 1)
                    fh.write(b'\x00')
        with base_chain.lock:
            base_chain.update_size()

    async def get_merkle_for_transaction(self, tx_hash, tx_height):
        """Ask the active interface's server for the merkle branch of a tx.

        Returns the raw 'blockchain.transaction.get_merkle' result. Requires
        an active interface (self.interface must not be None).
        """
        session = self.interface.session
        params = [tx_hash, tx_height]
        return await session.send_request(
            'blockchain.transaction.get_merkle', params)

    def broadcast_transaction_from_non_network_thread(self, tx, timeout=10):
        """Blocking wrapper around broadcast_transaction() for other threads.

        Schedules the coroutine on self.asyncio_loop and blocks until its
        result is available.
        Note: calling this from the network thread will deadlock it.
        """
        coro = self.broadcast_transaction(tx, timeout=timeout)
        future = asyncio.run_coroutine_threadsafe(coro, self.asyncio_loop)
        return future.result()

    async def broadcast_transaction(self, tx, timeout=10):
        """Broadcast `tx` through the active interface's server.

        Returns a (success, message) tuple.
        - On success, message is the txid echoed back by the server.
        - Otherwise success is False and message starts with "error: ".
        Exceptions (derived from Exception) raised while sending the request
        are turned into such an error tuple instead of propagating.
        """
        try:
            out = await self.interface.session.send_request(
                'blockchain.transaction.broadcast', [str(tx)], timeout=timeout)
        except asyncio.TimeoutError:
            return False, "error: operation timed out"
        except Exception as e:
            return False, "error: " + str(e)

        if out != tx.txid():
            # The server may reply with a non-str (e.g. a structured error);
            # str() avoids a TypeError from the concatenation.
            return False, "error: " + str(out)
        return True, out

    async def request_chunk(self,
                            height,
                            tip,
                            session=None,
                            can_return_early=False):
        """Download and connect the 2016-header chunk that contains `height`.

        The request starts at the chunk's first height. When `tip` is given,
        the header count is capped at `tip - start` (never below 0).
        With can_return_early, returns None if the chunk is already being
        requested.
        Otherwise returns (conn, count):
        - conn is the result of connect_chunk();
        - count is the server-reported header count, or 0 if conn is falsy.
        """
        if session is None:
            session = self.interface.session
        chunk_len = 2016
        index = height // chunk_len
        if can_return_early and index in self.requested_chunks:
            return
        first_height = index * chunk_len
        wanted = chunk_len
        if tip is not None:
            wanted = max(0, min(chunk_len, tip - first_height))
        self.requested_chunks.add(index)
        try:
            res = await session.send_request('blockchain.block.headers',
                                             [first_height, wanted])
        finally:
            self.requested_chunks.discard(index)
        conn = self.blockchain().connect_chunk(index, res['hex'])
        added = res['count'] if conn else 0
        return conn, added

    @with_interface_lock
    def blockchain(self):
        """Return the Blockchain object we currently follow.

        If the active interface is on a known chain, that chain's forkpoint
        becomes blockchain_index. Otherwise the stored index is used.
        """
        iface = self.interface
        if iface and iface.blockchain is not None:
            self.blockchain_index = iface.blockchain.forkpoint
        return blockchain.blockchains[self.blockchain_index]

    @with_interface_lock
    def get_blockchains(self):
        """Map each known chain id to the connected interfaces following it.

        Chains that no connected interface follows are left out.
        """
        with self.blockchains_lock:
            chains = list(blockchain.blockchains.items())
        connected = list(self.interfaces.values())
        result = {}
        for chain_id, chain in chains:
            followers = [iface for iface in connected if iface.blockchain == chain]
            if followers:
                result[chain_id] = followers
        return result

    def follow_chain(self, index):
        """Follow the blockchain whose id is `index`.

        The choice is saved in the config, and we switch to the first
        connected interface on that chain, if any. If an interface is then
        active, the network parameters are re-applied with its server.

        Raises Exception('blockchain not found', index) for an unknown id.
        """
        chain = blockchain.blockchains.get(index)
        if not chain:
            raise Exception('blockchain not found', index)
        self.blockchain_index = index
        self.config.set_key('blockchain_index', index)
        with self.interface_lock:
            candidates = list(self.interfaces.values())
        on_chain = next((iface for iface in candidates if iface.blockchain == chain),
                        None)
        if on_chain is not None:
            self.switch_to_interface(on_chain.server)

        with self.interface_lock:
            if not self.interface:
                return
            params = self.get_parameters()
            host, port, protocol = deserialize_server(self.interface.server)
            self.set_parameters(params._replace(host=host, port=port,
                                                protocol=protocol))

    def get_local_height(self):
        """Return the height of the chain we currently follow."""
        chain = self.blockchain()
        return chain.height()

    def export_checkpoints(self, path):
        """Write the followed chain's checkpoints to `path` as indented JSON.

        Meant to be run manually from the console to generate checkpoints.
        """
        checkpoints = self.blockchain().get_checkpoints()
        with open(path, 'w', encoding='utf-8') as out_file:
            serialized = json.dumps(checkpoints, indent=4)
            out_file.write(serialized)

    def start(self, fx=None):
        """Run the network's main coroutine on a new background thread.

        That thread drives self.asyncio_loop with run_until_complete(main()).
        main() first initialises the headers file. It then runs
        maintain_sessions(), and `fx` if given, inside self.main_taskgroup.
        """
        self.main_taskgroup = TaskGroup()

        async def main():
            self.init_headers_file()
            async with self.main_taskgroup as group:
                await group.spawn(self.maintain_sessions())
                if fx:
                    await group.spawn(fx)

        main_coro = main()
        self._wrapper_thread = threading.Thread(
            target=self.asyncio_loop.run_until_complete, args=(main_coro, ))
        self._wrapper_thread.start()

    def stop(self):
        """Ask the main task group to cancel its remaining tasks.

        The cancellation is scheduled on self.asyncio_loop from any thread.
        This call does not wait for it to complete.
        """
        cancel_coro = self.main_taskgroup.cancel_remaining()
        asyncio.run_coroutine_threadsafe(cancel_coro, self.asyncio_loop)

    def join(self):
        """Wait at most one second for the thread started by start()."""
        wrapper = self._wrapper_thread
        wrapper.join(timeout=1)

    async def maintain_sessions(self):
        """Connection housekeeping loop. It never returns on its own.

        Every 0.1 seconds it:
          * spawns new_interface() for each server waiting in socket_queue;
          * reports interfaces that recorded an exception to connection_down();
          * starts random interfaces until num_server connections are open or
            in progress;
          * clears disconnected_servers every NODES_RETRY_INTERVAL seconds;
          * keeps a main interface:
            - with auto_connect, switches to a random connected server when
              neither connected nor connecting;
            - otherwise retries the default server, throttled by
              SERVER_RETRY_INTERVAL;
          * when connected, requests fee estimates if the config needs them.
        """
        while True:
            while self.socket_queue.qsize() > 0:
                server = self.socket_queue.get()
                asyncio.get_event_loop().create_task(
                    self.new_interface(server))

            # Iterate over a snapshot taken under the lock. Methods that
            # mutate self.interfaces may run on other threads (they are
            # guarded by interface_lock), and changing a dict while iterating
            # it raises RuntimeError.
            with self.interface_lock:
                interface_items = list(self.interfaces.items())
            remove = []
            for server, iface in interface_items:
                if iface.fut.done() and not iface.exception:
                    assert False, "interface future should not finish without exception"
                if iface.exception:
                    if not iface.fut.done():
                        try:
                            iface.fut.cancel()
                        except Exception as e:
                            self.print_error('exception while cancelling fut',
                                             e)
                    err = iface.exception
                    self.print_error(iface.server, "errored because:", str(err),
                                     str(type(err)))
                    remove.append(server)
            for server in remove:
                self.connection_down(server)

            # nodes
            now = time.time()
            for _ in range(self.num_server - len(self.interfaces) -
                           len(self.connecting)):
                self.start_random_interface()
            if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
                self.print_error('network: retrying connections')
                self.disconnected_servers = set([])
                self.nodes_retry_time = now

            # main interface
            if not self.is_connected():
                if self.auto_connect:
                    if not self.is_connecting():
                        self.switch_to_random_interface()
                else:
                    if self.default_server in self.disconnected_servers:
                        if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
                            self.disconnected_servers.remove(
                                self.default_server)
                            self.server_retry_time = now
                    else:
                        self.switch_to_interface(self.default_server)
            else:
                if self.config.is_fee_estimates_update_required():
                    await self.interface.group.spawn(
                        self.request_fee_estimates(self.interface))

            await asyncio.sleep(0.1)
示例#4
0
class Daemon(Logger):
    """Daemon object for the wallet application.

    It owns the network, the FxThread, the loaded wallets (keyed by
    standardized path) and the optional JSON-RPC commands server and pay
    server. Their jobs run in self.taskgroup on the asyncio event loop.
    """

    network: Optional[Network]
    gui_object: Optional[Union['gui.qt.ElectrumGui', 'gui.kivy.ElectrumGui']]

    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        """Set up the daemon and schedule its jobs on the event loop.

        :param config: configuration to use.
        :param fd: file descriptor for the commands server. When None and
            listen_jsonrpc is set, one is obtained via get_file_descriptor().
        :param listen_jsonrpc: whether to run the JSON-RPC commands server.
        :raises Exception: if the daemon lock cannot be obtained, or if the
            asyncio event loop is not running within about 30 seconds.
        """
        Logger.__init__(self)
        self.config = config
        # on_stop() removes the daemon lockfile only when this is set: without
        # it, get_file_descriptor() (which takes the lock) is never called here.
        self.listen_jsonrpc = listen_jsonrpc
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        if 'wallet_path' in config.cmdline_options:
            self.logger.warning("Ignoring parameter 'wallet_path' for daemon. "
                                "Use the load_wallet command instead.")
        self.asyncio_loop = asyncio.get_event_loop()
        # The loop is expected to be run by another thread; wait for it.
        asyncio_wait_time = 0
        while not self.asyncio_loop.is_running():
            if asyncio_wait_time > 30:
                raise Exception('event loop not running for 30 seconds')
            time.sleep(0.1)
            asyncio_wait_time += 0.1
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        self.current_wallet_path = None
        daemon_jobs = []
        # Setup commands server
        self.commands_server = None
        if listen_jsonrpc:
            self.commands_server = CommandsServer(self, fd)
            daemon_jobs.append(self.commands_server.run())
        # pay server
        self.pay_server = None
        payserver_address = self.config.get_netaddress('payserver_address')
        if not config.get('offline') and payserver_address:
            self.pay_server = PayServer(self, payserver_address)
            daemon_jobs.append(self.pay_server.run())
        if self.network:
            self.network.start(jobs=[self.fx.run])

        self.stopping_soon = threading.Event()
        self.stopped_event = asyncio.Event()
        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs),
                                         self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        """Run `jobs` in self.taskgroup until the group is cancelled or dies.

        stopping_soon is always set on exit, which wakes up run_daemon().
        """
        if jobs is None:
            jobs = []
        self.logger.info("starting taskgroup.")
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                # run forever (until cancel)
                await group.spawn(asyncio.Event().wait)
        except asyncio.CancelledError:
            raise
        except Exception:
            self.logger.exception("taskgroup died.")
        finally:
            self.logger.info("taskgroup stopped.")
            self.stopping_soon.set()

    def load_wallet(self,
                    path,
                    password,
                    *,
                    manual_upgrades=True,
                    set_current=False) -> Optional[Abstract_Wallet]:
        """Load the wallet file at `path` and start it on the network.

        If `path` is already loaded, the existing wallet is returned.
        Returns None when:
        - the file does not exist;
        - it is encrypted and no password was given;
        - its data still requires a split, an upgrade, a pending action, or
          an unfinished multisig setup.
        :param set_current: also make this wallet the current wallet.
        """
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            if set_current:
                self.current_wallet_path = path
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.upgrade_done:
            storage.backup_old_version()
        if getattr(storage, 'backup_message', None):
            log_backup_msg = ' '.join(storage.backup_message.split('\n'))
            self.logger.info(f'{log_backup_msg}')
            if not self.config.get('detach'):
                util.print_stderr(f'{storage.backup_message}\n')
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        if db.check_unfinished_multisig():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        if set_current:
            self.current_wallet_path = path
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        """Register an already-constructed wallet under its storage path."""
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
        """Return the loaded wallet for `path`, or None."""
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        """Return a copy of the path -> wallet mapping."""
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        """Stop the wallet at `path` and delete its file.

        Returns True iff a file existed at `path` and was removed.
        """
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        # Blocks on a future run on self.asyncio_loop, so it must not be
        # called from the event loop thread.
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if self.current_wallet_path == path:
            self.current_wallet_path = None
        if not wallet:
            return False
        fut = asyncio.run_coroutine_threadsafe(wallet.stop(),
                                               self.asyncio_loop)
        fut.result()
        return True

    def run_daemon(self):
        """Block until stopping_soon is set (or Ctrl-C), then call on_stop()."""
        try:
            self.stopping_soon.wait()
        except KeyboardInterrupt:
            self.stopping_soon.set()
        self.on_stop()

    async def stop(self):
        """Request shutdown and wait for stopped_event, which on_stop() sets last."""
        self.stopping_soon.set()
        await self.stopped_event.wait()

    def on_stop(self):
        """Shut the daemon down.

        Steps, in order:
        - stop the GUI (if any);
        - stop all wallets;
        - stop the network and the daemon taskgroup;
        - remove the lockfile (only when listening for JSON-RPC);
        - set stopped_event.
        It blocks on a future scheduled on self.asyncio_loop, so it must not
        be called from the event loop thread.
        """
        try:
            self.logger.info("on_stop() entered. initiating shutdown")
            if self.gui_object:
                self.gui_object.stop()

            async def stop_async():
                self.logger.info("stopping all wallets")
                # Snapshot: stop_wallet()/load_wallet() may modify
                # self._wallets from other threads while we await below.
                wallets = list(self._wallets.values())
                async with TaskGroup() as group:
                    for wallet in wallets:
                        await group.spawn(wallet.stop())
                self.logger.info("stopping network and taskgroup")
                # NOTE(review): ignore_after(2) presumably bounds this step
                # to ~2 seconds without raising — confirm.
                async with ignore_after(2):
                    async with TaskGroup() as group:
                        if self.network:
                            await group.spawn(
                                self.network.stop(full_shutdown=True))
                        await group.spawn(self.taskgroup.cancel_remaining())

            fut = asyncio.run_coroutine_threadsafe(stop_async(),
                                                   self.asyncio_loop)
            fut.result()
        finally:
            # A daemon that does not listen for JSON-RPC never took the lock;
            # removing the lockfile would unlock another running daemon.
            if self.listen_jsonrpc:
                self.logger.info("removing lockfile")
                remove_lockfile(get_lockfile(self.config))
            self.logger.info("stopped")
            self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set)

    def run_gui(self, config, plugins):
        """Run the configured GUI (default 'qt') in the calling thread.

        'lite' and 'classic' map to 'qt'. Exceptions from the GUI are logged
        and re-raised. on_stop() is always called once the GUI exits.
        """
        # Thread.setName() is deprecated since Python 3.10; set .name instead.
        threading.current_thread().name = 'GUI'
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrum_zcash.gui.' + gui_name,
                             fromlist=['electrum_zcash'])
            self.gui_object = gui.ElectrumGui(config, self, plugins)
            self.gui_object.main()
        except BaseException as e:
            self.logger.error(
                f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            self.on_stop()
示例#5
0
class Daemon(Logger):
    """Daemon object for the wallet application.

    It owns the network, the FxThread, the loaded wallets (keyed by
    standardized path) and the optional commands server, pay server and
    watchtower. Their jobs run in self.taskgroup on the asyncio event loop.
    """

    network: Optional[Network]
    gui_object: Optional['gui.BaseElectrumSysGui']

    @profiler
    def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
        """Set up the daemon and schedule its jobs on the event loop.

        :param config: configuration to use.
        :param fd: file descriptor for the commands server. When None and
            listen_jsonrpc is set, one is obtained via get_file_descriptor().
        :param listen_jsonrpc: whether to run the JSON-RPC commands server.
            stop() removes the lockfile only when this is set.
        :raises Exception: if the daemon lock cannot be obtained.
        """
        Logger.__init__(self)
        self.config = config
        self.listen_jsonrpc = listen_jsonrpc
        if fd is None and listen_jsonrpc:
            fd = get_file_descriptor(config)
            if fd is None:
                raise Exception('failed to lock daemon; already running?')
        if 'wallet_path' in config.cmdline_options:
            self.logger.warning("Ignoring parameter 'wallet_path' for daemon. "
                                "Use the load_wallet command instead.")
        self.asyncio_loop = asyncio.get_event_loop()
        self.network = None
        if not config.get('offline'):
            self.network = Network(config, daemon=self)
        self.fx = FxThread(config, self.network)
        self.gui_object = None
        # path -> wallet;   make sure path is standardized.
        self._wallets = {}  # type: Dict[str, Abstract_Wallet]
        daemon_jobs = []
        # Setup commands server
        self.commands_server = None
        if listen_jsonrpc:
            self.commands_server = CommandsServer(self, fd)
            daemon_jobs.append(self.commands_server.run())
        # pay server
        self.pay_server = None
        payserver_address = self.config.get_netaddress('payserver_address')
        if not config.get('offline') and payserver_address:
            self.pay_server = PayServer(self, payserver_address)
            daemon_jobs.append(self.pay_server.run())
        # server-side watchtower
        self.watchtower = None
        watchtower_address = self.config.get_netaddress('watchtower_address')
        if not config.get('offline') and watchtower_address:
            self.watchtower = WatchTowerServer(self.network, watchtower_address)
            daemon_jobs.append(self.watchtower.run)
        if self.network:
            self.network.start(jobs=[self.fx.run])
            # prepare lightning functionality, also load channel db early
            if self.config.get('use_gossip', False):
                self.network.start_gossip()

        self._stop_entered = False
        self._stopping_soon_or_errored = threading.Event()
        self._stopped_event = threading.Event()
        self.taskgroup = TaskGroup()
        asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs), self.asyncio_loop)

    @log_exceptions
    async def _run(self, jobs: Iterable = None):
        """Run `jobs` in self.taskgroup until the group is cancelled or dies.

        If the group dies, the exception is logged and sent to the crash
        reporter. _stopping_soon_or_errored is always set on exit.
        """
        if jobs is None:
            jobs = []
        self.logger.info("starting taskgroup.")
        try:
            async with self.taskgroup as group:
                for job in jobs:
                    await group.spawn(job)
                await group.spawn(asyncio.Event().wait)  # run forever (until cancel)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.logger.exception("taskgroup died.")
            util.send_exception_to_crash_reporter(e)
        finally:
            self.logger.info("taskgroup stopped.")
            # note: we could just "await self.stop()", but in that case GUI users would
            #       not see the exception (especially if the GUI did not start yet).
            self._stopping_soon_or_errored.set()

    def load_wallet(self, path, password, *, manual_upgrades=True) -> Optional[Abstract_Wallet]:
        """Load the wallet file at `path` and start it on the network.

        If `path` is already loaded, the existing wallet is returned.
        Returns None when:
        - the file does not exist;
        - it is encrypted and no password was given;
        - its data still requires a split, an upgrade, or a pending action.
        """
        path = standardize_path(path)
        # wizard will be launched if we return
        if path in self._wallets:
            wallet = self._wallets[path]
            return wallet
        storage = WalletStorage(path)
        if not storage.file_exists():
            return
        if storage.is_encrypted():
            if not password:
                return
            storage.decrypt(password)
        # read data, pass it to db
        db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
        if db.requires_split():
            return
        if db.requires_upgrade():
            return
        if db.get_action():
            return
        wallet = Wallet(db, storage, config=self.config)
        wallet.start_network(self.network)
        self._wallets[path] = wallet
        return wallet

    def add_wallet(self, wallet: Abstract_Wallet) -> None:
        """Register an already-constructed wallet under its storage path."""
        path = wallet.storage.path
        path = standardize_path(path)
        self._wallets[path] = wallet

    def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
        """Return the loaded wallet for `path`, or None."""
        path = standardize_path(path)
        return self._wallets.get(path)

    def get_wallets(self) -> Dict[str, Abstract_Wallet]:
        """Return a copy of the path -> wallet mapping."""
        return dict(self._wallets)  # copy

    def delete_wallet(self, path: str) -> bool:
        """Stop the wallet at `path` and delete its file.

        Returns True iff a file existed at `path` and was removed.
        """
        self.stop_wallet(path)
        if os.path.exists(path):
            os.unlink(path)
            return True
        return False

    def stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        # note: this must not be called from the event loop. # TODO raise if so
        fut = asyncio.run_coroutine_threadsafe(self._stop_wallet(path), self.asyncio_loop)
        return fut.result()

    async def _stop_wallet(self, path: str) -> bool:
        """Returns True iff a wallet was found."""
        path = standardize_path(path)
        wallet = self._wallets.pop(path, None)
        if not wallet:
            return False
        await wallet.stop()
        return True

    def run_daemon(self):
        """Block until shutdown is requested or the taskgroup errored.

        On Ctrl-C, triggers stop() itself. Then waits until stop() finishes.
        """
        try:
            self._stopping_soon_or_errored.wait()
        except KeyboardInterrupt:
            asyncio.run_coroutine_threadsafe(self.stop(), self.asyncio_loop).result()
        self._stopped_event.wait()

    async def stop(self):
        """Shut the daemon down. Calls after the first return immediately.

        Steps, in order:
        - stop the GUI (if any);
        - stop all wallets;
        - stop the network and the daemon taskgroup;
        - remove the lockfile (only when listening for JSON-RPC);
        - set _stopped_event.
        """
        if self._stop_entered:
            return
        self._stop_entered = True
        self._stopping_soon_or_errored.set()
        self.logger.info("stop() entered. initiating shutdown")
        try:
            if self.gui_object:
                self.gui_object.stop()
            self.logger.info("stopping all wallets")
            # Snapshot: the awaits below let other code (e.g. _stop_wallet(),
            # or load_wallet() from another thread) modify self._wallets,
            # which would break iteration over the live dict.
            wallets = list(self._wallets.values())
            async with TaskGroup() as group:
                for wallet in wallets:
                    await group.spawn(wallet.stop())
            self.logger.info("stopping network and taskgroup")
            # NOTE(review): ignore_after(2) presumably bounds this step to
            # ~2 seconds without raising — confirm.
            async with ignore_after(2):
                async with TaskGroup() as group:
                    if self.network:
                        await group.spawn(self.network.stop(full_shutdown=True))
                    await group.spawn(self.taskgroup.cancel_remaining())
        finally:
            if self.listen_jsonrpc:
                self.logger.info("removing lockfile")
                remove_lockfile(get_lockfile(self.config))
            self.logger.info("stopped")
            self._stopped_event.set()

    def run_gui(self, config, plugins):
        """Run the configured GUI (default 'qt') in the calling thread.

        'lite' and 'classic' map to 'qt'. Exceptions from the GUI are logged
        and re-raised. The daemon is always stopped once the GUI exits.
        """
        threading.current_thread().name = 'GUI'
        gui_name = config.get('gui', 'qt')
        if gui_name in ['lite', 'classic']:
            gui_name = 'qt'
        self.logger.info(f'launching GUI: {gui_name}')
        try:
            gui = __import__('electrumsys.gui.' + gui_name, fromlist=['electrumsys'])
            self.gui_object = gui.ElectrumSysGui(config=config, daemon=self, plugins=plugins)
            if not self._stop_entered:
                self.gui_object.main()
            else:
                # If daemon.stop() was called before gui_object got created, stop gui now.
                self.gui_object.stop()
        except BaseException as e:
            self.logger.error(f'GUI raised exception: {repr(e)}. shutting down.')
            raise
        finally:
            # app will exit now
            asyncio.run_coroutine_threadsafe(self.stop(), self.asyncio_loop).result()