class GeventedHTTPTransport(AsyncTransport, HTTPTransport):
    """HTTP transport that performs each send in its own gevent greenlet.

    At most ``maximum_outstanding_requests`` sends may be in flight at once;
    further callers block in :meth:`async_send` until a slot frees up.
    """

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100, *args, **kwargs):
        """
        :param parsed_url: forwarded to ``HTTPTransport.__init__``.
        :param maximum_outstanding_requests: size of the semaphore that caps
            the number of concurrently running send greenlets.
        :raises ImportError: when gevent is not available.
        """
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)
        super(GeventedHTTPTransport, self).__init__(parsed_url, *args, **kwargs)

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.

        Blocks while the outstanding-request limit is reached.  The slot is
        released in :meth:`_done` once the greenlet finishes.

        :returns: the spawned greenlet.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        greenlet = gevent.spawn(
            super(GeventedHTTPTransport, self).send, data, headers
        )
        # Greenlet.link() returns None, so keep a reference and return the
        # greenlet itself rather than the result of link().
        greenlet.link(lambda g: self._done(g, success_cb, failure_cb))
        return greenlet

    def _done(self, greenlet, success_cb, failure_cb, *args):
        """Release the request slot and dispatch to the matching callback."""
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
def __init__(self, size=None, greenlet_class=None):
    """
    Initialise a pool: a group whose active membership is capped by *size*.

    :keyword int size: When given, this non-negative integer bounds how many
        greenlets may be active in the pool at the same time. Special values:

        * ``None`` (the default) imposes no limit at all. This is handy when
          greenlets must be tracked but not throttled, as
          :class:`gevent.pywsgi.WSGIServer` does. A :class:`Group` may do the
          same job more efficiently.

        * ``0`` produces a pool that can never hold an active greenlet, so any
          spawn blocks forever. It is only useful when the application calls
          :meth:`wait_available` with a timeout and consults
          :meth:`free_count` before spawning.

    :keyword greenlet_class: optional override for the class used to create
        greenlets in this pool.
    :raises ValueError: if *size* is negative.
    """
    is_negative = size is not None and size < 0
    if is_negative:
        raise ValueError('size must not be negative: %r' % (size, ))
    Group.__init__(self)
    self.size = size
    if greenlet_class is not None:
        self.greenlet_class = greenlet_class
    # An unbounded pool gets a no-op semaphore; otherwise one sized to the cap.
    self._semaphore = DummySemaphore() if size is None else Semaphore(size)
class JobSharedLock(object):
    """
    Lock shared by every call of one job.

    A job method may name a lock; all calls for that job share it, so only
    one job holding this lock runs at a time. The set of jobs associated
    with the lock is tracked in ``self.jobs``.
    """

    def __init__(self, queue, name):
        self.queue, self.name = queue, name
        self.jobs = []
        self.semaphore = Semaphore()

    def add_job(self, job):
        """Append *job* to the tracked job list."""
        self.jobs.append(job)

    def get_jobs(self):
        """Return the tracked job list (the live list, not a copy)."""
        return self.jobs

    def remove_job(self, job):
        """Drop *job* from the tracked list (ValueError if it is absent)."""
        self.jobs.remove(job)

    def locked(self):
        """Report whether the underlying semaphore is currently held."""
        sem = self.semaphore
        return sem.locked()

    def acquire(self):
        """Acquire the underlying semaphore, returning its result."""
        sem = self.semaphore
        return sem.acquire()

    def release(self):
        """Release the underlying semaphore, returning its result."""
        sem = self.semaphore
        return sem.release()
class GeventedHTTPTransport(HTTPTransport):
    """HTTP transport whose ``send`` runs in a background gevent greenlet.

    At most ``maximum_outstanding_requests`` sends may be in flight; further
    calls to :meth:`send` block until a slot frees up.
    """

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        """
        :param parsed_url: forwarded to ``HTTPTransport.__init__``.
        :param maximum_outstanding_requests: semaphore size capping the number
            of concurrently running send greenlets.
        :raises ImportError: when gevent is not available.
        """
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]

    def send(self, data, headers):
        """
        Spawn an async request to a remote webserver.

        :returns: the spawned greenlet.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        greenlet = gevent.spawn(super(GeventedHTTPTransport, self).send, data, headers)
        # Only the callback is passed: in gevent >= 1.0 the second positional
        # parameter of link() is the link-wrapper class, and passing ``self``
        # there raised TypeError after the slot had already been acquired.
        greenlet.link(self._done)
        return greenlet

    def _done(self, *args):
        """Release the request slot once a send greenlet has finished."""
        self._lock.release()
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):
    """HTTP transport that performs each send in its own gevent greenlet.

    At most ``maximum_outstanding_requests`` sends may be in flight at once.
    """

    scheme = ["gevent+http", "gevent+https"]

    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        """
        :param parsed_url: forwarded to ``HTTPTransport.__init__``.
        :param maximum_outstanding_requests: semaphore size capping the number
            of concurrently running send greenlets.
        :raises ImportError: when gevent is not available.
        """
        if not has_gevent:
            raise ImportError("GeventedHTTPTransport requires gevent.")
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split("+", 1)[-1]

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.

        :returns: the spawned greenlet.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        greenlet = gevent.spawn(super(GeventedHTTPTransport, self).send, data, headers)
        # Greenlet.link() returns None; return the greenlet itself.
        greenlet.link(lambda g: self._done(g, success_cb, failure_cb))
        return greenlet

    def _done(self, greenlet, success_cb, failure_cb, *args):
        """Release the request slot and dispatch to the matching callback."""
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            # A failed greenlet stores its error in ``exception``; ``value``
            # is None in that case.
            failure_cb(greenlet.exception)
def test_iwait_nogarbage(self):
    """Leaving an ``iwait`` context must unlink it from unfinished objects."""
    first = Semaphore()
    second = Semaphore()
    releaser = gevent.spawn(first.release)
    with gevent.iwait((first, second)) as waiter:
        self.assertEqual(first, next(waiter))
        # While iterating, the still-pending semaphore carries one link.
        self.assertEqual(second.linkcount(), 1)
    # After the context exits that link must be gone.
    self.assertEqual(second.linkcount(), 0)
    releaser.get()
def __init__(self, size=None, greenlet_class=None):
    """
    Set up a pool capped at *size* concurrently active greenlets.

    ``size=None`` means no cap (a no-op semaphore is used).

    :keyword greenlet_class: optional override for the class used to create
        greenlets in this pool.
    :raises ValueError: if *size* is negative.
    """
    unbounded = size is None
    if not unbounded and size < 0:
        raise ValueError('size must not be negative: %r' % (size, ))
    Group.__init__(self)
    self.size = size
    if greenlet_class is not None:
        self.greenlet_class = greenlet_class
    if unbounded:
        limiter = DummySemaphore()
    else:
        limiter = Semaphore(size)
    self._semaphore = limiter
def test_release_twice(self):
    """Both rawlink callbacks must run after two consecutive releases."""
    s = Semaphore()
    result = []
    s.rawlink(lambda s: result.append('a'))
    s.release()
    s.rawlink(lambda s: result.append('b'))
    s.release()
    gevent.sleep(0.001)
    # The order in which the callbacks fire is not guaranteed, so compare
    # order-insensitively to avoid a flaky test.
    self.assertEqual(sorted(result), ['a', 'b'])
def test_release_twice(self):
    """Releasing twice must fire both registered rawlink callbacks."""
    sem = Semaphore()
    seen = []
    sem.rawlink(lambda _: seen.append('a'))
    sem.release()
    sem.rawlink(lambda _: seen.append('b'))
    sem.release()
    gevent.sleep(0.001)
    # Callbacks may fire in either order; compare ignoring order.
    expected = ['a', 'b']
    self.assertEqual(sorted(seen), expected)
class Pool(Group):
    """A :class:`Group` whose number of active greenlets is bounded by *size*."""

    def __init__(self, size=None, greenlet_class=None):
        """
        :param size: maximum number of active greenlets, or None for no limit.
        :param greenlet_class: optional override for the greenlet class.
        :raises ValueError: if *size* is negative.
        """
        if size is not None and size < 0:
            raise ValueError("size must not be negative: %r" % (size,))
        Group.__init__(self)
        self.size = size
        if greenlet_class is not None:
            self.greenlet_class = greenlet_class
        if size is None:
            self._semaphore = DummySemaphore()
        else:
            self._semaphore = Semaphore(size)

    def wait_available(self, timeout=None):
        """Block until a slot is free, or until *timeout* seconds elapse.

        The optional *timeout* is forwarded to the semaphore's ``wait``.
        Whatever it returns is passed back to the caller; callers that ignore
        the return value see the old behaviour.
        """
        return self._semaphore.wait(timeout)

    def full(self):
        """Return True when no further greenlet can be added without blocking."""
        return self.free_count() <= 0

    def free_count(self):
        """Number of free slots (1 for an unbounded pool, never negative)."""
        if self.size is None:
            return 1
        return max(0, self.size - len(self))

    def add(self, greenlet):
        """Acquire a slot and add *greenlet*; give the slot back if adding fails."""
        self._semaphore.acquire()
        try:
            Group.add(self, greenlet)
        except BaseException:
            # Any failure, including GreenletExit/KeyboardInterrupt, must not
            # leak the slot just acquired.
            self._semaphore.release()
            raise

    def _discard(self, greenlet):
        """Remove *greenlet* from the group and release its slot."""
        Group._discard(self, greenlet)
        self._semaphore.release()
class GeventConnectionPool(psycopg2.pool.AbstractConnectionPool):
    """psycopg2 pool that makes callers wait (cooperatively) for a free slot.

    A gevent semaphore sized ``maxconn`` bounds the number of checked-out
    connections, so :meth:`getconn` blocks instead of raising when the pool
    is exhausted.
    """

    def __init__(self, minconn, maxconn, *args, **kwargs):
        self.semaphore = Semaphore(maxconn)
        psycopg2.pool.AbstractConnectionPool.__init__(self, minconn, maxconn, *args, **kwargs)

    def getconn(self, *args, **kwargs):
        """Wait for a free slot, then check out a connection."""
        self.semaphore.acquire()
        try:
            return self._getconn(*args, **kwargs)
        except BaseException:
            # No connection was handed out, so the slot must be returned;
            # otherwise every failure permanently shrinks the pool.
            self.semaphore.release()
            raise

    def putconn(self, *args, **kwargs):
        """Return a connection to the pool and free its slot."""
        self._putconn(*args, **kwargs)
        self.semaphore.release()

    # Not sure what to do about this one...
    closeall = psycopg2.pool.AbstractConnectionPool._closeall
def __init__(self, parsed_url, maximum_outstanding_requests=100, *args, **kwargs):
    """
    Create the transport with a cap on concurrently outstanding requests.

    :param parsed_url: forwarded, together with any extra positional and
        keyword arguments, to the parent transport's ``__init__``.
    :param maximum_outstanding_requests: initial value of the semaphore
        stored in ``self._lock``.
    :raises ImportError: when gevent is not available.
    """
    if not has_gevent:
        raise ImportError('GeventedHTTPTransport requires gevent.')
    in_flight_limit = Semaphore(maximum_outstanding_requests)
    self._lock = in_flight_limit
    parent = super(GeventedHTTPTransport, self)
    parent.__init__(parsed_url, *args, **kwargs)
def __init__(self, logger, zkservers, svc_name, inst, watchers=None,
             zpostfix="", freq=10):
    """
    Initialise the discovery greenlet's state (no ZooKeeper I/O here).

    :param logger: logger to use; when None, ``logging.basicConfig()`` is
        called and the ``logging`` module itself is used.
    :param zkservers: ZooKeeper server string (comma-separated when reported
        to the connection-state tracker).
    :param svc_name: service name used in the published path.
    :param inst: instance identifier.
    :param watchers: mapping of watched names to callbacks; a fresh empty dict
        is used when omitted (previously a single ``{}`` default was shared
        by every instance).
    :param zpostfix: suffix appended to ``/analytics-discovery-``.
    :param freq: stored as ``self._freq``.
    """
    gevent.Greenlet.__init__(self)
    self._svc_name = svc_name
    self._inst = inst
    self._zk_server = zkservers
    # initialize logging and other stuff
    if logger is None:
        logging.basicConfig()
        self._logger = logging
    else:
        self._logger = logger
    self._conn_state = None
    # Needs _svc_name, _zk_server, _logger and _conn_state set above.
    self._sandesh_connection_info_update(status='INIT', message='')
    self._zkservers = zkservers
    self._zk = None
    self._pubinfo = None
    self._publock = Semaphore()
    self._watchers = watchers if watchers is not None else {}
    self._wchildren = {}
    self._pendingcb = set()
    self._zpostfix = zpostfix
    self._basepath = "/analytics-discovery-" + self._zpostfix
    self._reconnect = None
    self._freq = freq
def __init__(self, url):
    """
    Open the source at *url* and start a greenlet that reads it.

    All state is initialised before the reader greenlet is spawned. The old
    order only worked because ``gevent.spawn`` defers execution until the
    current greenlet yields.
    """
    self.sock_file = get_data(url)
    # NOTE(review): typecode 'c' exists only on Python 2's array module.
    self.file_buffer = array.array('c')
    self.file_buffer_ptr = 0
    self.evt = Semaphore(value=0)
    gevent.spawn(self.start_reading_bytes)
    # gevent.spawn(self.start_reading_dummy_bytes)
def __init__(self, uplink=None, downlink=None):
    """
    Set up the ZeroMQ sockets used for uplink (SUB) and downlink (PUB) traffic.

    The SUB socket subscribes to the ``NEW_NODE`` and ``NODE_EXIT`` topics
    and is registered with a poller for POLLIN. The PUB socket is guarded by
    ``downlinkSocketLock``.
    """
    cls = self.__class__
    self.log = logging.getLogger("{module}.{name}".format(
        module=cls.__module__, name=cls.__name__))
    self.downlink = downlink
    self.uplink = uplink
    self.recv_callback = None
    self.context = zmq.Context()
    self.poller = zmq.Poller()

    # one SUB socket for uplink communication over topics
    self.ul_socket = self.context.socket(zmq.SUB)
    # Python 3 needs the *_string variant for str topics.
    if sys.version_info.major >= 3:
        subscribe = self.ul_socket.setsockopt_string
    else:
        subscribe = self.ul_socket.setsockopt
    for topic in ("NEW_NODE", "NODE_EXIT"):
        subscribe(zmq.SUBSCRIBE, topic)

    self.downlinkSocketLock = Semaphore(value=1)
    # one PUB socket for downlink communication over topics
    self.dl_socket = self.context.socket(zmq.PUB)

    # register the uplink socket in the poller
    self.poller.register(self.ul_socket, zmq.POLLIN)
    self.importedPbClasses = {}
def __init__(self, provider, n, **kwargs):
    """
    Engine that runs tasks on at most *n* services.

    Only the first configuration reported by the provider is used.
    ``num_unscheduled`` starts at -1 as a placeholder value.
    """
    super(LimitEngine, self).__init__(provider, **kwargs)
    configurations = self.provider.configurations()
    self.conf = configurations[0]
    self.n = n
    self.services = []
    self.num_unscheduled = -1
    self.le_lock = Semaphore()
def __init__(self, rabbit_server, rabbit_port, rabbit_user, rabbit_password,
             notification_level, ironic_notif_mgr_obj, **kwargs):
    """
    Connect to RabbitMQ and set up the exchange and notification queues.

    :param rabbit_server: parsed by ``self._parse_rabbit_hosts``; the first
        host entry's ``"host"`` is used.
    :param notification_level: passed to ``self._set_up_queues``.
    :param ironic_notif_mgr_obj: stored as ``_ironic_notification_manager``.
    :param kwargs: accepted and ignored.

    Exits the process (``SystemExit`` with no code) if no queues were set up.
    """
    self._rabbit_port = rabbit_port
    self._rabbit_user = rabbit_user
    self._rabbit_password = rabbit_password
    self._rabbit_hosts = self._parse_rabbit_hosts(rabbit_server)
    self._rabbit_ip = self._rabbit_hosts[0]["host"]
    self._notification_level = notification_level
    self._ironic_notification_manager = ironic_notif_mgr_obj
    self._conn_lock = Semaphore()
    # NOTE: SIGTERM handler registration is currently disabled (see the
    # commented-out line below). When enabled, it lets the lock be released
    # promptly on shutdown instead of waiting minutes for a new master.
    # If any app using this wants to register their own sigterm handler,
    # then we will have to modify this function to perhaps take an argument
    # gevent.signal(signal.SIGTERM, self.sigterm_handler)
    self._url = "amqp://%s:%s@%s:%s/" % (self._rabbit_user, self._rabbit_password,
                                         self._rabbit_ip, self._rabbit_port)
    # self._conn_state = ConnectionStatus.INIT
    self._conn = kombu.Connection(self._url)
    self._exchange = self._set_up_exchange()
    # Initialised before the call in case _set_up_queues inspects it.
    self._queues = []
    self._queues = self._set_up_queues(self._notification_level)
    if not self._queues:
        # The bare ``exit()`` builtin is only installed by the ``site``
        # module; SystemExit is always available.
        raise SystemExit()
def __init__(
    self,
    rabbit_ip,
    rabbit_port,
    rabbit_user,
    rabbit_password,
    rabbit_vhost,
    rabbit_ha_mode,
    q_name,
    subscribe_cb,
    logger,
):
    """
    Store RabbitMQ connection settings and register a SIGTERM handler.

    NOTE(review): ``rabbit_ha_mode`` and ``q_name`` are accepted but not
    stored by this constructor. Confirm that is intended.
    """
    self._rabbit_ip, self._rabbit_port = rabbit_ip, rabbit_port
    self._rabbit_user, self._rabbit_password = rabbit_user, rabbit_password
    self._rabbit_vhost = rabbit_vhost
    self._subscribe_cb, self._logger = subscribe_cb, logger
    self._publish_queue = Queue()
    self._conn_lock = Semaphore()
    self.obj_upd_exchange = kombu.Exchange(
        "vnc_config.object-update", "fanout", durable=False
    )
    # The SIGTERM handler lets the lock be released promptly; without it, it
    # can take several minutes before a new master is elected. Apps that need
    # their own SIGTERM handler would require this to become configurable.
    gevent.signal(signal.SIGTERM, self.sigterm_handler)
class Puppet(object):
    """Front end for a worker process reached through a duplex gipc pipe.

    Tasks are sent with :meth:`submit_task` and their results collected with
    :meth:`fetch_result`. A background greenlet (:meth:`receive_info`) reads
    the pipe, resolving per-task AsyncResults (positive ids) or updating
    ``status`` (id ``-1``).
    """

    def __init__(self):
        self.results = {}
        self.status = ("Idle", None)
        self.worker = None
        self.operator = None
        self.lock = Semaphore()

    def receive_info(self):
        """Forever read ``(info_id, info)`` pairs from the pipe and dispatch them."""
        while True:
            (info_id, info) = self.task_queue.get()
            if info_id > 0:
                self.results[info_id].set(info)
            elif info_id == -1:
                self.status = info

    def submit_task(self, tid, task_info):
        """Register a result slot for *tid* and send the task to the worker."""
        self.results[tid] = AsyncResult()
        # ``with`` guarantees the lock is released even if put() raises;
        # previously a failure left it held and blocked every later submit.
        with self.lock:
            self.task_queue.put((tid, task_info))

    def fetch_result(self, tid):
        """Block until *tid*'s result is available, then forget and return it."""
        res = self.results[tid].get()
        del self.results[tid]
        return res

    def hire_worker(self):
        """Start the worker process and the greenlet that reads its replies."""
        tq_worker, self.task_queue = gipc.pipe(duplex=True)
        self.worker = gipc.start_process(target=process_worker, args=(tq_worker,))
        self.operator = gevent.spawn(self.receive_info)

    def fire_worker(self):
        """Terminate the worker process and kill the reader greenlet."""
        self.worker.terminate()
        self.operator.kill()

    def get_attr(self, attr):
        """Return CPU/memory usage percentages, or an attribute of this object."""
        if attr == "cpu":
            return psutil.cpu_percent(interval=None)
        elif attr == "memory":
            return psutil.virtual_memory().percent
        else:
            return getattr(self, attr, "No such attribute")

    def current_tasks(self):
        """Return the ids of tasks whose results have not been fetched yet."""
        return self.results.keys()
def __init__(self, parent, io_loop=None):
    """
    Initialise per-connection state with no socket attached yet.

    ``io_loop`` is accepted but not used by this constructor.
    """
    self._parent = parent
    self._closing = False
    self._user_queries, self._cursor_cache = {}, {}
    self._write_mutex = Semaphore()
    self._socket = None
def __init__(self, parsed_url, maximum_outstanding_requests=100):
    """
    Create the transport and strip the ``gevent+`` prefix from its URL.

    :param maximum_outstanding_requests: initial value of the semaphore
        stored in ``self._lock``.
    :raises ImportError: when gevent is not available.
    """
    if not has_gevent:
        raise ImportError('GeventedHTTPTransport requires gevent.')
    self._lock = Semaphore(maximum_outstanding_requests)
    super(GeventedHTTPTransport, self).__init__(parsed_url)
    # "gevent+" is only a transport selector, not a real protocol; keep the
    # part after the first '+'.
    without_prefix = self._url.split('+', 1)[-1]
    self._url = without_prefix
def test_renamed_label_refresh(db, default_account, thread, message, imapuid,
                               folder, mock_imapclient, monkeypatch):
    """Imapuids get their labels refreshed after LabelRenameHandler runs.

    Also checks that the handler waits while the shared semaphore is held.
    NOTE: the test sleeps 10 seconds while the semaphore is held.
    """
    # Check that imapuids see their labels refreshed after running
    # the LabelRenameHandler.
    msg_uid = imapuid.msg_uid
    # Seed the uid with a label that the handler is expected to replace.
    uid_dict = {msg_uid: GmailFlags((), ('stale label',), ('23',))}

    update_metadata(default_account.id, folder.id, folder.canonical_name,
                    uid_dict, db.session)

    # Flags the mocked IMAP server will report after the rename.
    new_flags = {msg_uid: {'FLAGS': ('\\Seen',),
                           'X-GM-LABELS': ('new label',),
                           'MODSEQ': ('23',)}}

    mock_imapclient._data['[Gmail]/All mail'] = new_flags

    mock_imapclient.add_folder_data(folder.name, new_flags)

    # Every search returns just this message uid.
    monkeypatch.setattr(MockIMAPClient, 'search',
                        lambda x, y: [msg_uid])

    semaphore = Semaphore(value=1)

    rename_handler = LabelRenameHandler(default_account.id,
                                        default_account.namespace.id,
                                        'new label', semaphore)

    # Acquire the semaphore to check that LabelRenameHandlers block if
    # the semaphore is in-use.
    semaphore.acquire()
    rename_handler.start()

    # Wait 10 secs and check that the data hasn't changed.
    gevent.sleep(10)

    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'stale label'
    # Let the handler proceed and wait for it to finish.
    semaphore.release()
    rename_handler.join()

    db.session.refresh(imapuid)
    # Now check that the label got updated.
    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'new label'
class DatagramServer(BaseServer):
    """A UDP server.

    Incoming datagrams are read from the raw socket in :meth:`do_read`, and
    outgoing writes through :meth:`sendto` are serialised by a semaphore.
    """

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        # The raw (non-gevent) socket, if possible
        self._socket = None
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        """Create the listener if none exists and expose its raw socket."""
        if not hasattr(self, 'socket'):
            # FIXME: clean up the socket lifetime
            # pylint:disable=attribute-defined-outside-init
            listener = self.get_listener(self.address, self.family)
            self.socket = listener
            self.address = listener.getsockname()
        raw = self.socket
        # Unwrap to the underlying socket object when the wrapper has one.
        try:
            raw = raw._sock
        except AttributeError:
            pass
        self._socket = raw

    @classmethod
    def get_listener(cls, address, family=None):
        """Build a UDP socket bound to *address*, honouring ``reuse_addr``."""
        return _udp_socket(address, reuse_addr=cls.reuse_addr, family=family)

    def do_read(self):
        """Return ``(data, address)`` for one datagram, or None on EWOULDBLOCK."""
        try:
            data, address = self._socket.recvfrom(8192)
        except _socket.error as err:
            if err.args[0] != EWOULDBLOCK:
                raise
            return None
        return data, address

    def sendto(self, *args):
        """Send one datagram while holding the write lock."""
        with self._writelock:
            self.socket.sendto(*args)
def __init__(self, actor_config, location="", rules=None):
    """
    Set up the actor's queues, rule matcher and rule lock.

    Creates the ``inbox`` and ``nomatch`` queues and consumes from ``inbox``.

    NOTE(review): ``location`` and ``rules`` are not used by this constructor.
    ``rules`` now defaults to None instead of a shared mutable ``{}``.
    """
    Actor.__init__(self, actor_config)
    self.pool.createQueue("inbox")
    self.pool.createQueue("nomatch")
    self.registerConsumer(self.consume, "inbox")
    self.__active_rules = {}
    self.match = MatchRules()
    self.rule_lock = Semaphore()
def __init__(self, s_id, conf):
    """
    Record the identifier and configuration; all runtime state starts empty.

    Start/finish times and the puppet start out unset (None), the queue is an
    empty set, and the status string is "Unknown".
    """
    self.s_id, self.conf = s_id, conf
    self.start_time = self.finish_time = self.puppet = None
    self.lock = Semaphore()
    self.queue = set()
    self._status = "Unknown"
    self.started = False
class DatagramServer(BaseServer):
    """A UDP server"""

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        # The raw (non-gevent) socket, if possible. Defined up front so the
        # attribute exists before init_socket() runs.
        self._socket = None
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        """Create the listener if none exists and expose its raw socket."""
        if not hasattr(self, 'socket'):
            self.socket = self.get_listener(self.address, self.family)
            self.address = self.socket.getsockname()
        self._socket = self.socket
        try:
            self._socket = self._socket._sock
        except AttributeError:
            pass

    @classmethod
    def get_listener(cls, address, family=None):
        """Build a UDP socket bound to *address*, honouring ``reuse_addr``."""
        return _udp_socket(address, reuse_addr=cls.reuse_addr, family=family)

    def do_read(self):
        """Return ``(data, address)`` for one datagram, or None on EWOULDBLOCK."""
        try:
            data, address = self._socket.recvfrom(8192)
        except _socket.error as err:
            if err.args[0] == EWOULDBLOCK:
                return
            raise
        return data, address

    def sendto(self, *args):
        """Send one datagram while holding the write lock."""
        self._writelock.acquire()
        try:
            self.socket.sendto(*args)
        finally:
            self._writelock.release()
class LimitEngine(EngineBase):
    """Engine that schedules tasks onto at most ``n`` provider services."""

    def __init__(self, provider, n, **kwargs):
        super(LimitEngine, self).__init__(provider, **kwargs)
        self.conf = self.provider.configurations()[0]
        self.n = n
        self.services = []
        self.num_unscheduled = -1
        self.le_lock = Semaphore()

    def before_eval(self):
        """Reset the count of tasks still to be scheduled to the DAG size."""
        self.num_unscheduled = len(self.dag)

    def after_eval(self):
        """Stop every service started during evaluation and forget them."""
        for s in self.services:
            self.provider.stop_service(s)
        self.services = []

    def which_service(self, task):
        """Choose a service for *task*.

        The choice is, in order: an idle service; otherwise a new service
        while fewer than ``n`` exist; otherwise the least-loaded service.
        """
        self.num_unscheduled -= 1
        for s in self.services:
            if len(s.tasks) == 0:
                return s
        if len(self.services) < self.n:
            s = self.provider.start_service(len(self.services) + 1, self.conf)
            self.services.append(s)
        else:
            s = min(self.services, key=lambda x: len(x.tasks))
        return s

    def after_task(self, task, service):
        """Once nothing is left to schedule, stop idle services."""
        # ``with`` guarantees the lock is released even if stop_service
        # raises; the old acquire/release pair left it held forever.
        with self.le_lock:
            if self.num_unscheduled == 0 and len(service.tasks) == 0:
                self.provider.stop_service(service)
                # NOTE(review): *service* is itself idle here, so the loop
                # below stops it a second time. Confirm stop_service
                # tolerates that.
                for s in self.services:
                    if len(s.tasks) == 0:
                        self.provider.stop_service(s)

    def current_services(self):
        """Return the list of services currently tracked by the engine."""
        return self.services
def _test_atomic():
    """Check which gevent operations are allowed inside ``lets.atomic()``.

    Cases marked ``o`` are expected to run without error inside the context.
    Cases marked ``x`` are expected to raise AssertionError there, presumably
    because they would yield to another greenlet. The leading underscore
    keeps pytest from collecting this function by default.
    """
    # NOTE: Nested context by comma is not available in Python 2.6.
    # o -- No gevent.
    with lets.atomic():
        1 + 2 + 3
    # x -- gevent.sleep()
    with pytest.raises(AssertionError):
        with lets.atomic():
            gevent.sleep(0.1)
    # x -- gevent.sleep() with 0 seconds.
    with pytest.raises(AssertionError):
        with lets.atomic():
            gevent.sleep(0)
    # o -- Greenlet.spawn()
    with lets.atomic():
        gevent.spawn(gevent.sleep, 0.1)
    # x -- Greenlet.join()
    with pytest.raises(AssertionError):
        with lets.atomic():
            g = gevent.spawn(gevent.sleep, 0.1)
            g.join()
    # x -- Greenlet.get()
    with pytest.raises(AssertionError):
        with lets.atomic():
            g = gevent.spawn(gevent.sleep, 0.1)
            g.get()
    # x -- gevent.joinall()
    with pytest.raises(AssertionError):
        with lets.atomic():
            g = gevent.spawn(gevent.sleep, 0.1)
            gevent.joinall([g])
    # o -- Event.set(), AsyncResult.set()
    with lets.atomic():
        Event().set()
        AsyncResult().set()
    # x -- Event.wait()
    with pytest.raises(AssertionError):
        with lets.atomic():
            Event().wait()
    # x -- AsyncResult.wait()
    with pytest.raises(AssertionError):
        with lets.atomic():
            AsyncResult().wait()
    # x -- Channel.put()
    with pytest.raises(AssertionError):
        with lets.atomic():
            ch = Channel()
            ch.put(123)
    # o -- First Semaphore.acquire()
    with lets.atomic():
        lock = Semaphore()
        lock.acquire()
    # x -- Second Semaphore.acquire()
    with pytest.raises(AssertionError):
        with lets.atomic():
            lock = Semaphore()
            lock.acquire()
            lock.acquire()
    # Back to normal.
    gevent.sleep(1)
def __init__(self, **kwargs):
    """
    Initialise account state and task pools, then apply keyword overrides.

    Each keyword argument is set as an attribute after ``reset()`` runs, so it
    overrides whatever ``reset()`` initialised.
    """
    self.account = self  # needed for InputFunctions.solve_* functions
    self.multi_account = False
    self.lock = Semaphore()
    self.check_pool = VariableSizePool(size=self.max_check_tasks)
    self.download_pool = VariableSizePool(size=self.max_download_tasks)
    self.search_pool = VariableSizePool(size=10)
    self.reset()
    # dict.items() works on both Python 2 and 3; iteritems() is 2-only.
    for k, v in kwargs.items():
        setattr(self, k, v)
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):
    """HTTP transport that sends in gevent greenlets, or inline in "oneway" mode."""

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100, oneway=False):
        """
        :param maximum_outstanding_requests: semaphore size capping the number
            of concurrently running send greenlets.
        :param oneway: enables oneway mode when it is the string ``'true'``
            (as parsed from URL options) or the boolean ``True``. Previously a
            boolean ``True`` was silently treated as off.
        :raises ImportError: when gevent is not available.
        """
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)
        self._oneway = oneway is True or oneway == 'true'
        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.

        :returns: the spawned greenlet in normal mode, or the ``urlopen``
            result in oneway mode.
        """
        if not self._oneway:
            self._lock.acquire()
            greenlet = gevent.spawn(
                super(GeventedHTTPTransport, self).send, data, headers
            )
            # Greenlet.link() returns None; return the greenlet itself.
            greenlet.link(lambda g: self._done(g, success_cb, failure_cb))
            return greenlet
        else:
            # NOTE(review): this path performs the request in the calling
            # greenlet and never invokes success_cb/failure_cb. Confirm that
            # is intended for "oneway".
            req = urllib2.Request(self._url, headers=headers)
            return urlopen(url=req, data=data, timeout=self.timeout,
                           verify_ssl=self.verify_ssl, ca_certs=self.ca_certs)

    def _done(self, greenlet, success_cb, failure_cb, *args):
        """Release the request slot and dispatch to the matching callback."""
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
def __init__(self, *args, **kwargs):
    """
    Initialise the UDP server and the lock that serialises ``sendto`` calls.

    ``_socket`` is defined before the base initialiser runs so that the
    attribute always exists, even before ``init_socket()`` fills it in.
    """
    # The raw (non-gevent) socket, if possible
    self._socket = None
    BaseServer.__init__(self, *args, **kwargs)
    from gevent.lock import Semaphore
    self._writelock = Semaphore()
def __init__(self, sck, server):
    """
    Wrap socket *sck* belonging to *server*.

    Two independent semaphores are created, presumably one for reading and
    one for writing (named ``_rlock``/``_wlock``). The controller registry
    starts empty.
    """
    self._sck = sck
    self._rlock, self._wlock = Semaphore(), Semaphore()
    self._server = server
    self._living_controllers = {}
class AnalyticsDiscovery(gevent.Greenlet):
    """Greenlet that publishes this instance in ZooKeeper and watches peers.

    Published data lives at ``/analytics-discovery-<zpostfix>/<svc_name>/<inst>``
    (ephemeral). For every name in ``data_watchers``, the children of
    ``<basepath>/<name>`` are tracked and their JSON payloads handed to the
    callback. For every name in ``child_watchers``, a ChildrenWatch with the
    given callback is installed.
    """

    def _sandesh_connection_info_update(self, status, message):
        """Report the ZooKeeper connection state and log up/down transitions."""
        new_conn_state = getattr(ConnectionStatus, status)
        ConnectionState.update(conn_type=ConnectionType.ZOOKEEPER,
                               name=self._svc_name, status=new_conn_state,
                               server_addrs=self._zk_server.split(','),
                               message=message)

        if (self._conn_state and self._conn_state != ConnectionStatus.DOWN and
                new_conn_state == ConnectionStatus.DOWN):
            msg = 'Connection to Zookeeper down: %s' % (message)
            self._logger.error(msg)
        if (self._conn_state and self._conn_state != new_conn_state and
                new_conn_state == ConnectionStatus.UP):
            msg = 'Connection to Zookeeper ESTABLISHED'
            self._logger.error(msg)

        self._conn_state = new_conn_state
    # end _sandesh_connection_info_update

    def _zk_listen(self, state):
        """Kazoo state listener: update status; exit the process on LOST."""
        self._logger.error("Analytics Discovery listen %s" % str(state))
        if state == KazooState.CONNECTED:
            self._sandesh_connection_info_update(
                status='UP', message='Connection to Zookeeper re-established')
            self._logger.error("Analytics Discovery to publish %s" % str(self._pubinfo))
            self._reconnect = True
        elif state == KazooState.LOST:
            self._logger.error("Analytics Discovery connection LOST")
            # Lost the session with ZooKeeper Server
            # Best of option we have is to exit the process and restart all
            # over again
            self._sandesh_connection_info_update(
                status='DOWN', message='Connection to Zookeeper lost')
            os._exit(2)
        elif state == KazooState.SUSPENDED:
            self._logger.error("Analytics Discovery connection SUSPENDED")
            # Update connection info
            self._sandesh_connection_info_update(
                status='INIT', message='Connection to zookeeper lost. Retrying')

    def _zk_datawatch(self, watcher, child, data, stat, event="unknown"):
        """DataWatch callback: cache (or drop) a child's parsed JSON payload."""
        self._logger.error(\
                "Analytics Discovery %s ChildData : child %s, data %s, event %s" % \
                (watcher, child, data, event))
        if data:
            data_dict = json.loads(data)
            self._wchildren[watcher][child] = OrderedDict(
                sorted(data_dict.items()))
        else:
            if child in self._wchildren[watcher]:
                del self._wchildren[watcher][child]
        # The watcher callback itself runs later, from the _run loop.
        if self._data_watchers[watcher]:
            self._pendingcb.add(watcher)

    def _zk_watcher(self, watcher, children):
        """ChildrenWatch callback: request a full resync in the _run loop."""
        self._logger.error("Analytics Discovery Watcher %s Children %s" % (watcher, children))
        self._reconnect = True

    def __init__(self, logger, zkservers, svc_name, inst, data_watchers=None,
                 child_watchers=None, zpostfix="", freq=10):
        """
        :param logger: logger to use; None selects the ``logging`` module
            after ``logging.basicConfig()``.
        :param zkservers: ZooKeeper hosts string passed to KazooClient.
        :param data_watchers: name -> callback (or falsy) for data tracking;
            defaults to a fresh empty dict per instance.
        :param child_watchers: name -> ChildrenWatch callback; defaults to a
            fresh empty dict per instance.
        :param zpostfix: suffix of the ``/analytics-discovery-`` base path.
        :param freq: seconds slept between iterations of the run loop.
        """
        gevent.Greenlet.__init__(self)
        self._svc_name = svc_name
        self._inst = inst
        self._zk_server = zkservers
        # initialize logging and other stuff
        if logger is None:
            logging.basicConfig()
            self._logger = logging
        else:
            self._logger = logger
        self._conn_state = None
        self._sandesh_connection_info_update(
            status='INIT', message='Connection to Zookeeper initialized')
        self._zkservers = zkservers
        self._zk = None
        self._pubinfo = None
        self._publock = Semaphore()
        self._data_watchers = data_watchers if data_watchers is not None else {}
        self._child_watchers = child_watchers if child_watchers is not None else {}
        self._wchildren = {}
        self._pendingcb = set()
        self._zpostfix = zpostfix
        self._basepath = "/analytics-discovery-" + self._zpostfix
        self._reconnect = None
        self._freq = freq

    def publish(self, pubinfo):
        """Create/update this instance's ephemeral node, or delete it if None."""
        # This function can be called concurrently by the main AlarmDiscovery
        # processing loop as well as by clients.
        # It is NOT re-entrant
        # ``with`` releases the lock even if something escapes (e.g.
        # GreenletExit when the greenlet is killed mid-publish, or an error in
        # the handler below). Previously the lock leaked and later callers
        # blocked forever.
        with self._publock:
            self._pubinfo = pubinfo
            if self._conn_state == ConnectionStatus.UP:
                try:
                    self._logger.error("ensure %s" % (self._basepath + "/" + self._svc_name))
                    self._logger.error("zk state %s (%s)" % (self._zk.state, self._zk.client_state))
                    self._zk.ensure_path(self._basepath + "/" + self._svc_name)
                    self._logger.error("check for %s/%s/%s" % \
                                    (self._basepath, self._svc_name, self._inst))
                    if pubinfo is not None:
                        if self._zk.exists("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst)):
                            self._zk.set("%s/%s/%s" % \
                                    (self._basepath, self._svc_name, self._inst),
                                    self._pubinfo)
                        else:
                            self._zk.create("%s/%s/%s" % \
                                    (self._basepath, self._svc_name, self._inst),
                                    self._pubinfo, ephemeral=True)
                    else:
                        if self._zk.exists("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst)):
                            self._logger.error("withdrawing published info!")
                            self._zk.delete("%s/%s/%s" % \
                                    (self._basepath, self._svc_name, self._inst))
                except Exception as ex:
                    template = "Exception {0} in AnalyticsDiscovery publish. Args:\n{1!r}"
                    messag = template.format(type(ex).__name__, ex.args)
                    self._logger.error("%s : traceback %s for %s info %s" % \
                            (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
                    self._sandesh_connection_info_update(
                        status='DOWN', message='Reconnect to Zookeeper to handle exception')
                    self._reconnect = True
            else:
                self._logger.error("Analytics Discovery cannot publish while down")

    def _run(self):
        """Connect (retrying every second), install watches, then resync forever."""
        while True:
            self._logger.error("Analytics Discovery zk start")
            self._zk = KazooClient(hosts=self._zkservers)
            self._zk.add_listener(self._zk_listen)
            try:
                self._zk.start()
                while self._conn_state != ConnectionStatus.UP:
                    gevent.sleep(1)
                break
            except Exception as e:
                # Update connection info
                self._sandesh_connection_info_update(status='DOWN',
                                                     message=str(e))
                self._zk.remove_listener(self._zk_listen)
                try:
                    self._zk.stop()
                    self._zk.close()
                except Exception as ex:
                    template = "Exception {0} in AnalyticsDiscovery zk stop/close. Args:\n{1!r}"
                    messag = template.format(type(ex).__name__, ex.args)
                    self._logger.error("%s : traceback %s for %s" % \
                        (messag, traceback.format_exc(), self._svc_name))
                finally:
                    self._zk = None
                gevent.sleep(1)

        try:
            # Update connection info
            self._sandesh_connection_info_update(
                status='UP', message='Connection to Zookeeper established')
            self._reconnect = False
            # Done connecting to ZooKeeper

            for wk in self._data_watchers.keys():
                self._zk.ensure_path(self._basepath + "/" + wk)
                self._wchildren[wk] = {}
                self._zk.ChildrenWatch(self._basepath + "/" + wk,
                        partial(self._zk_watcher, wk))

            for wk in self._child_watchers.keys():
                self._zk.ensure_path(self._basepath + "/" + wk)
                self._zk.ChildrenWatch(self._basepath + "/" + wk,
                        self._child_watchers[wk])

            # Trigger the initial publish
            self._reconnect = True

            while True:
                try:
                    if not self._reconnect:
                        pending_list = list(self._pendingcb)
                        self._pendingcb = set()
                        for wk in pending_list:
                            if self._data_watchers[wk]:
                                self._data_watchers[wk](\
                                        sorted(self._wchildren[wk].values()))

                    # If a reconnect happens during processing, don't lose it
                    while self._reconnect:
                        self._logger.error("Analytics Discovery %s reconnect" \
                                % self._svc_name)
                        self._reconnect = False
                        self._pendingcb = set()
                        self.publish(self._pubinfo)

                        for wk in self._data_watchers.keys():
                            self._zk.ensure_path(self._basepath + "/" + wk)
                            children = self._zk.get_children(self._basepath + "/" + wk)

                            old_children = set(self._wchildren[wk].keys())
                            new_children = set(children)

                            # Remove contents for the children who are gone
                            # (DO NOT remove the watch)
                            for elem in old_children - new_children:
                                del self._wchildren[wk][elem]

                            # Overwrite existing children, or create new ones
                            for elem in new_children:
                                # Create a watch for new children
                                if elem not in self._wchildren[wk]:
                                    self._zk.DataWatch(self._basepath + "/" + \
                                            wk + "/" + elem,
                                            partial(self._zk_datawatch, wk, elem))

                                data_str, _ = self._zk.get(\
                                        self._basepath + "/" + wk + "/" + elem)
                                data_dict = json.loads(data_str)
                                self._wchildren[wk][elem] = \
                                        OrderedDict(sorted(data_dict.items()))
                                self._logger.error(\
                                    "Analytics Discovery %s ChildData : child %s, data %s, event %s" % \
                                    (wk, elem, self._wchildren[wk][elem], "GET"))
                            if self._data_watchers[wk]:
                                self._data_watchers[wk](sorted(
                                        self._wchildren[wk].values()))

                    gevent.sleep(self._freq)
                except gevent.GreenletExit:
                    self._logger.error("Exiting AnalyticsDiscovery for %s" % \
                            self._svc_name)
                    self._zk.remove_listener(self._zk_listen)
                    gevent.sleep(1)
                    # Best-effort shutdown: log, but never propagate, failures.
                    try:
                        self._zk.stop()
                    except:
                        self._logger.error("Stopping kazooclient failed")
                    else:
                        self._logger.error("Stopping kazooclient successful")
                    try:
                        self._zk.close()
                    except:
                        self._logger.error("Closing kazooclient failed")
                    else:
                        self._logger.error("Closing kazooclient successful")
                    break

                except Exception as ex:
                    template = "Exception {0} in AnalyticsDiscovery reconnect. Args:\n{1!r}"
                    messag = template.format(type(ex).__name__, ex.args)
                    self._logger.error("%s : traceback %s for %s info %s" % \
                        (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
                    self._reconnect = True

        except Exception as ex:
            template = "Exception {0} in AnalyticsDiscovery run. Args:\n{1!r}"
            messag = template.format(type(ex).__name__, ex.args)
            self._logger.error("%s : traceback %s for %s info %s" % \
                    (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
            raise SystemExit
class ScpSever():
    """Wrapper around a replaceable connection that waits for a replacement.

    After a read/write error the connection is frozen and ``connerr`` is set.
    Later callers wait (up to the configured reuse timeout, after which
    :meth:`close` fires) for :meth:`replace_conn` to install a new connection.
    """

    def __init__(self, conn):
        self.conn = conn
        self.closed = False
        self.connerr = None
        self.conn_mutex = RLock()
        self.conn_cond = Semaphore(0)

    def read(self, size):
        """Read up to *size*; returns ``(data, err)``, with err set only when closed."""
        conn, err = self.acquire_conn()
        if err:
            # conn is closed
            return '', err
        data, err = conn.read(size)
        if err:
            # freeze, waiting for reuse
            conn.freeze()
            self.connerr = err
        return data, None

    def write(self, data):
        """Write *data*; returns the close error when closed, else None.

        The closed path used to return a ``('', err)`` tuple while every other
        path returned a single value. It now returns just ``err``.
        """
        conn, err = self.acquire_conn()
        if err:
            # conn is closed
            return err
        # Use the connection acquire_conn handed out, not self.conn, which
        # replace_conn may swap concurrently.
        err = conn.write(data)
        if err:
            # freeze, waiting for reuse
            conn.freeze()
            self.connerr = err
        return None

    def close(self):
        """Close the connection once, wake any waiter, and return ``connerr``."""
        self.conn_mutex.acquire()
        try:
            if not self.closed:
                self.conn.close()
                self.closed = True
                # NOTE(review): ``error`` is not defined in this class; it must
                # be a module-level name. Confirm it exists.
                self.connerr = error
        finally:
            # Always wake a waiter and drop the mutex, even if close() raised;
            # previously an exception left the mutex held forever.
            self.conn_cond.release()
            result = self.connerr
            self.conn_mutex.release()
        return result

    # Timeout countdown: close the connection if no replacement arrives in time.
    def _star_wait(self):
        reuse_timeout = int(config['listen']['reuse_time'])
        self.time_task = Timer(reuse_timeout, self.close)
        self.time_task.start()

    def _stop_wait(self):
        self.time_task.cancel()

    def _cond_wait(self):
        """Drop the mutex while blocking on the condition semaphore, then retake it."""
        self.conn_mutex.release()
        self.conn_cond.acquire()
        self.conn_mutex.acquire()

    def acquire_conn(self):
        """Return ``(conn, None)`` when usable, or ``(None, connerr)`` once closed.

        While the connection is in an error state, waits for a replacement or
        for the reuse timeout to close it.
        """
        self.conn_mutex.acquire()
        conn = None
        connerr = None
        try:
            while True:
                if self.closed:
                    connerr = self.connerr
                    break
                elif self.connerr:
                    self._star_wait()
                    self._cond_wait()
                    self._stop_wait()
                else:
                    conn = self.conn
                    break
        finally:
            self.conn_mutex.release()
        return conn, connerr

    def replace_conn(self, conn):
        """Swap in *conn* (closing the old one) unless already closed.

        :returns: True if the connection was replaced.
        """
        self.conn_mutex.acquire()
        ret = False
        try:
            if not self.closed:
                # close old conn
                self.conn.close()
                # set new status
                self.conn = conn
                self.connerr = None
                ret = True
        finally:
            self.conn_cond.release()
            self.conn_mutex.release()
        return ret
class Rotkehlchen():
    """Top-level application object tying together the DB, exchanges, chains and accounting."""

    def __init__(self, args: argparse.Namespace) -> None:
        """Initialize the Rotkehlchen object

        May Raise:
        - SystemPermissionError if the given data directory's permissions
        are not correct.
        """
        self.lock = Semaphore()
        self.lock.acquire()

        # Can also be None after unlock if premium credentials did not
        # authenticate or premium server temporarily offline
        self.premium: Optional[Premium] = None
        self.user_is_logged_in: bool = False
        configure_logging(args)

        self.sleep_secs = args.sleep_secs
        if args.data_dir is None:
            self.data_dir = default_data_directory()
        else:
            self.data_dir = Path(args.data_dir)

        if not os.access(self.data_dir, os.W_OK | os.R_OK):
            raise SystemPermissionError(
                f'The given data directory {self.data_dir} is not readable or writable',
            )
        self.args = args
        self.msg_aggregator = MessagesAggregator()
        self.greenlet_manager = GreenletManager(msg_aggregator=self.msg_aggregator)
        self.exchange_manager = ExchangeManager(msg_aggregator=self.msg_aggregator)
        # Initialize the AssetResolver singleton
        AssetResolver(data_directory=self.data_dir)
        self.data = DataHandler(self.data_dir, self.msg_aggregator)
        self.cryptocompare = Cryptocompare(data_directory=self.data_dir, database=None)
        self.coingecko = Coingecko()
        self.icon_manager = IconManager(data_dir=self.data_dir, coingecko=self.coingecko)
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='periodically_query_icons_until_all_cached',
            method=self.icon_manager.periodically_query_icons_until_all_cached,
            batch_size=ICONS_BATCH_SIZE,
            sleep_time_secs=ICONS_QUERY_SLEEP,
        )
        # Initialize the Inquirer singleton
        Inquirer(
            data_dir=self.data_dir,
            cryptocompare=self.cryptocompare,
            coingecko=self.coingecko,
        )
        # Keeps how many trades we have found per location. Used for free user limiting
        self.actions_per_location: Dict[str, Dict[Location, int]] = {
            'trade': defaultdict(int),
            'asset_movement': defaultdict(int),
        }

        self.lock.release()
        self.shutdown_event = gevent.event.Event()

    def reset_after_failed_account_creation_or_login(self) -> None:
        """If the account creation or login failed make sure that the Rotki instance is clear

        Tricky instances are when after either failed premium credentials or user refusal
        to sync premium databases we relogged in.
        """
        self.cryptocompare.db = None

    def unlock_user(
            self,
            user: str,
            password: str,
            create_new: bool,
            sync_approval: Literal['yes', 'no', 'unknown'],
            premium_credentials: Optional[PremiumCredentials],
            initial_settings: Optional[ModifiableDBSettings] = None,
    ) -> None:
        """Unlocks an existing user or creates a new one if `create_new` is True

        May raise:
        - PremiumAuthenticationError if the password can't unlock the database.
        - AuthenticationError if premium_credentials are given and are invalid
        or can't authenticate with the server
        - DBUpgradeError if the rotki DB version is newer than the software or
        there is a DB upgrade and there is an error.
        - SystemPermissionError if the directory or DB file can not be accessed
        """
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
            initial_settings=initial_settings,
        )

        # unlock or create the DB
        self.password = password
        self.user_directory = self.data.unlock(user, password, create_new, initial_settings)
        self.data_importer = DataImporter(db=self.data.db)
        self.last_data_upload_ts = self.data.db.get_last_data_upload_ts()
        self.premium_sync_manager = PremiumSyncManager(data=self.data, password=password)
        # set the DB in the external services instances that need it
        self.cryptocompare.set_database(self.data.db)

        # Anything that was set above here has to be cleaned in case of failure in the next step
        # by reset_after_failed_account_creation_or_login()
        try:
            self.premium = self.premium_sync_manager.try_premium_at_start(
                given_premium_credentials=premium_credentials,
                username=user,
                create_new=create_new,
                sync_approval=sync_approval,
            )
        except PremiumAuthenticationError:
            # Reraise it only if this is during the creation of a new account where
            # the premium credentials were given by the user
            if create_new:
                raise
            self.msg_aggregator.add_error(
                'Tried to synchronize the database from remote but the local password '
                'does not match the one the remote DB has. Please change the password '
                'to be the same as the password of the account you want to sync from ',
            )
            # else let's just continue. User signed in successfully, but he just
            # has unauthenticable/invalid premium credentials remaining in his DB

        settings = self.get_settings()
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='submit_usage_analytics',
            method=maybe_submit_usage_analytics,
            should_submit=settings.submit_usage_analytics,
        )
        self.etherscan = Etherscan(database=self.data.db, msg_aggregator=self.msg_aggregator)
        historical_data_start = settings.historical_data_start
        eth_rpc_endpoint = settings.eth_rpc_endpoint
        # Initialize the price historian singleton
        PriceHistorian(
            data_directory=self.data_dir,
            history_date_start=historical_data_start,
            cryptocompare=self.cryptocompare,
        )
        self.accountant = Accountant(
            db=self.data.db,
            user_directory=self.user_directory,
            msg_aggregator=self.msg_aggregator,
            create_csv=True,
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=settings.anonymized_logs)
        exchange_credentials = self.data.db.get_exchange_credentials()
        self.exchange_manager.initialize_exchanges(
            exchange_credentials=exchange_credentials,
            database=self.data.db,
        )

        # Initialize blockchain querying modules
        ethereum_manager = EthereumManager(
            ethrpc_endpoint=eth_rpc_endpoint,
            etherscan=self.etherscan,
            database=self.data.db,
            msg_aggregator=self.msg_aggregator,
            greenlet_manager=self.greenlet_manager,
            connect_at_start=ETHEREUM_NODES_TO_CONNECT_AT_START,
        )
        Inquirer().inject_ethereum(ethereum_manager)
        self.chain_manager = ChainManager(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            ethereum_manager=ethereum_manager,
            msg_aggregator=self.msg_aggregator,
            database=self.data.db,
            greenlet_manager=self.greenlet_manager,
            premium=self.premium,
            eth_modules=settings.active_modules,
        )
        self.trades_historian = TradesHistorian(
            user_directory=self.user_directory,
            db=self.data.db,
            msg_aggregator=self.msg_aggregator,
            exchange_manager=self.exchange_manager,
            chain_manager=self.chain_manager,
        )
        self.user_is_logged_in = True
        log.debug('User unlocking complete')

    def logout(self) -> None:
        """Log out the current user and drop all per-user state. No-op if nobody is logged in."""
        if not self.user_is_logged_in:
            return

        user = self.data.username
        log.info(
            'Logging out user',
            user=user,
        )
        self.greenlet_manager.clear()
        del self.chain_manager
        self.exchange_manager.delete_all_exchanges()

        # Reset rotkehlchen logger to default
        LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

        del self.accountant
        del self.trades_historian
        del self.data_importer

        # Reset (rather than `del`) so that `self.premium` keeps existing as an
        # attribute: get_settings()/set_premium_credentials() read it, and
        # unlock_user() leaves it untouched when premium authentication fails.
        self.premium = None
        self.data.logout()
        self.password = ''
        self.cryptocompare.unset_database()

        # Make sure no messages leak to other user sessions
        self.msg_aggregator.consume_errors()
        self.msg_aggregator.consume_warnings()

        self.user_is_logged_in = False
        log.info(
            'User successfully logged out',
            user=user,
        )

    def set_premium_credentials(self, credentials: PremiumCredentials) -> None:
        """
        Sets the premium credentials for Rotki

        Raises PremiumAuthenticationError if the given key is rejected by the Rotkehlchen server
        """
        log.info('Setting new premium credentials')
        if self.premium is not None:
            self.premium.set_credentials(credentials)
        else:
            self.premium = premium_create_and_verify(credentials)

        # Keep the sync manager in step with the session's premium object;
        # deactivate_premium_status() updates it on the way down, so do the same here.
        self.premium_sync_manager.premium = self.premium
        # NOTE(review): chain_manager received `premium` at login and is not
        # updated here -- confirm whether it also needs the new object.
        self.data.db.set_rotkehlchen_premium(credentials)

    def delete_premium_credentials(self) -> Tuple[bool, str]:
        """Deletes the premium credentials for Rotki"""
        msg = ''

        success = self.data.db.del_rotkehlchen_premium()
        if success is False:
            msg = 'The database was unable to delete the Premium keys for the logged-in user'
        self.deactivate_premium_status()
        return success, msg

    def deactivate_premium_status(self) -> None:
        """Deactivate premium in the current session"""
        self.premium = None
        self.premium_sync_manager.premium = None
        self.chain_manager.deactivate_premium_status()

    def start(self) -> gevent.Greenlet:
        """Spawn the main loop in a new greenlet and return it."""
        return gevent.spawn(self.main_loop)

    def main_loop(self) -> None:
        """Rotki main loop that fires often and manages many different tasks

        Each task remembers the last time it run successfully and know how often it
        should run. So each task manages itself.
        """
        # super hacky -- organize better when recurring tasks are implemented
        # https://github.com/rotki/rotki/issues/1106
        xpub_derivation_scheduled = False
        while self.shutdown_event.wait(MAIN_LOOP_SECS_DELAY) is not True:
            if self.user_is_logged_in:
                log.debug('Main loop start')
                self.premium_sync_manager.maybe_upload_data_to_server()
                if not xpub_derivation_scheduled:
                    # 1 minute in the app's startup try to derive new xpub addresses
                    self.greenlet_manager.spawn_and_track(
                        after_seconds=60.0,
                        task_name='Derive new xpub addresses',
                        method=XpubManager(self.chain_manager).check_for_new_xpub_addresses,
                    )
                    xpub_derivation_scheduled = True
                log.debug('Main loop end')

    def get_blockchain_account_data(
            self,
            blockchain: SupportedBlockchain,
    ) -> Union[List[BlockchainAccountData], Dict[str, Any]]:
        """Return account data for *blockchain*.

        For every chain but Bitcoin this is the DB's account list as-is. For
        Bitcoin it is a dict with 'standalone' accounts and 'xpubs' entries,
        each xpub carrying the accounts derived from it (or None).
        """
        account_data = self.data.db.get_blockchain_account_data(blockchain)
        if blockchain != SupportedBlockchain.BITCOIN:
            return account_data

        xpub_data = self.data.db.get_bitcoin_xpub_data()
        addresses_to_account_data = {x.address: x for x in account_data}
        address_to_xpub_mappings = self.data.db.get_addresses_to_xpub_mapping(
            list(addresses_to_account_data.keys()),  # type: ignore
        )

        xpub_mappings: Dict['XpubData', List[BlockchainAccountData]] = {}
        for address, xpub_entry in address_to_xpub_mappings.items():
            if xpub_entry not in xpub_mappings:
                xpub_mappings[xpub_entry] = []
            xpub_mappings[xpub_entry].append(addresses_to_account_data[address])

        data: Dict[str, Any] = {'standalone': [], 'xpubs': []}
        # Add xpub data
        for xpub_entry in xpub_data:
            data_entry = xpub_entry.serialize()
            addresses = xpub_mappings.get(xpub_entry, None)
            data_entry['addresses'] = addresses if addresses and len(addresses) != 0 else None
            data['xpubs'].append(data_entry)
        # Add standalone addresses
        for account in account_data:
            if account.address not in address_to_xpub_mappings:
                data['standalone'].append(account)

        return data

    def add_blockchain_accounts(
            self,
            blockchain: SupportedBlockchain,
            account_data: List[BlockchainAccountData],
    ) -> BlockchainBalancesUpdate:
        """Adds new blockchain accounts

        Adds the accounts to the blockchain instance and queries them to get the
        updated balances. Also adds them in the DB

        May raise:
        - EthSyncError from modify_blockchain_account
        - InputError if the given accounts list is empty.
        - TagConstraintError if any of the given account data contain unknown tags.
        - RemoteError if an external service such as Etherscan is queried and
        there is a problem with its query.
        """
        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='adding',
            data_type='blockchain accounts',
        )
        address_type = blockchain.get_address_type()
        updated_balances = self.chain_manager.add_blockchain_accounts(
            blockchain=blockchain,
            accounts=[address_type(entry.address) for entry in account_data],
        )
        self.data.db.add_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )

        return updated_balances

    def edit_blockchain_accounts(
            self,
            blockchain: SupportedBlockchain,
            account_data: List[BlockchainAccountData],
    ) -> None:
        """Edits blockchain accounts

        Edits blockchain account data for the given accounts

        May raise:
        - InputError if the given accounts list is empty or if
        any of the accounts to edit do not exist.
        - TagConstraintError if any of the given account data contain unknown tags.
        """
        # First check for validity of account data addresses
        if len(account_data) == 0:
            raise InputError('Empty list of blockchain account data to edit was given')
        accounts = [x.address for x in account_data]
        unknown_accounts = set(accounts).difference(self.chain_manager.accounts.get(blockchain))
        if len(unknown_accounts) != 0:
            raise InputError(
                f'Tried to edit unknown {blockchain.value} '
                f'accounts {",".join(unknown_accounts)}',
            )

        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='editing',
            data_type='blockchain accounts',
        )

        # Finally edit the accounts
        self.data.db.edit_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )

        return None

    def remove_blockchain_accounts(
            self,
            blockchain: SupportedBlockchain,
            accounts: ListOfBlockchainAddresses,
    ) -> BlockchainBalancesUpdate:
        """Removes blockchain accounts

        Removes the accounts from the blockchain instance and queries them to get
        the updated balances. Also removes them from the DB

        May raise:
        - RemoteError if an external service such as Etherscan is queried and
        there is a problem with its query.
        - InputError if a non-existing account was given to remove
        """
        balances_update = self.chain_manager.remove_blockchain_accounts(
            blockchain=blockchain,
            accounts=accounts,
        )
        self.data.db.remove_blockchain_accounts(blockchain, accounts)
        return balances_update

    def process_history(
            self,
            start_ts: Timestamp,
            end_ts: Timestamp,
    ) -> Tuple[Dict[str, Any], str]:
        """Gather history in [start_ts, end_ts] and run it through the accountant.

        Returns the accountant's result and the (possibly empty) error string
        reported by the trades historian.
        """
        (
            error_or_empty,
            history,
            loan_history,
            asset_movements,
            eth_transactions,
            defi_events,
        ) = self.trades_historian.get_history(
            start_ts=start_ts,
            end_ts=end_ts,
            has_premium=self.premium is not None,
        )
        result = self.accountant.process_history(
            start_ts=start_ts,
            end_ts=end_ts,
            trade_history=history,
            loan_history=loan_history,
            asset_movements=asset_movements,
            eth_transactions=eth_transactions,
            defi_events=defi_events,
        )
        return result, error_or_empty

    @overload
    def _apply_actions_limit(
            self,
            location: Location,
            action_type: Literal['trade'],
            location_actions: List[Trade],
            all_actions: List[Trade],
    ) -> List[Trade]:
        ...

    @overload
    def _apply_actions_limit(
            self,
            location: Location,
            action_type: Literal['asset_movement'],
            location_actions: List[AssetMovement],
            all_actions: List[AssetMovement],
    ) -> List[AssetMovement]:
        ...

    def _apply_actions_limit(
            self,
            location: Location,
            action_type: Literal['trade', 'asset_movement'],
            location_actions: Union[List[Trade], List[AssetMovement]],
            all_actions: Union[List[Trade], List[AssetMovement]],
    ) -> Union[List[Trade], List[AssetMovement]]:
        """Take as many actions from location actions and add them to all actions as the limit permits

        Returns the modified (or not) all_actions
        """
        # If we are already at or above the limit return current actions disregarding this location
        actions_mapping = self.actions_per_location[action_type]
        current_num_actions = sum(x for _, x in actions_mapping.items())
        limit = LIMITS_MAPPING[action_type]
        if current_num_actions >= limit:
            return all_actions

        # Find out how many more actions can we return, and depending on that get
        # the number of actions from the location actions and add them to the total
        remaining_num_actions = limit - current_num_actions
        if remaining_num_actions < 0:
            remaining_num_actions = 0

        num_actions_to_take = min(len(location_actions), remaining_num_actions)

        actions_mapping[location] = num_actions_to_take
        all_actions.extend(location_actions[0:num_actions_to_take])  # type: ignore
        return all_actions

    def query_trades(
            self,
            from_ts: Timestamp,
            to_ts: Timestamp,
            location: Optional[Location],
    ) -> List[Trade]:
        """Queries trades for the given location and time range.
        If no location is given then all external and all exchange trades are queried.

        If the user does not have premium then a trade limit is applied.

        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        if location is not None:
            trades = self.query_location_trades(from_ts, to_ts, location)
        else:
            # A full query recounts every location from scratch. Without this,
            # stale per-exchange counts from a previous call are summed into the
            # limit, and _apply_actions_limit's early return never refreshes them,
            # so exchange trades could be dropped for good.
            self.actions_per_location['trade'].clear()
            trades = self.query_location_trades(from_ts, to_ts, Location.EXTERNAL)
            for name, exchange in self.exchange_manager.connected_exchanges.items():
                exchange_trades = exchange.query_trade_history(start_ts=from_ts, end_ts=to_ts)
                if self.premium is None:
                    trades = self._apply_actions_limit(
                        location=deserialize_location(name),
                        action_type='trade',
                        location_actions=exchange_trades,
                        all_actions=trades,
                    )
                else:
                    trades.extend(exchange_trades)

        # return trades with most recent first
        trades.sort(key=lambda x: x.timestamp, reverse=True)
        return trades

    def query_location_trades(
            self,
            from_ts: Timestamp,
            to_ts: Timestamp,
            location: Location,
    ) -> List[Trade]:
        """Query trades of a single location (DB for EXTERNAL, else a connected exchange).

        Returns an empty list if the location is not a connected exchange.
        """
        # clear the trades queried for this location
        self.actions_per_location['trade'][location] = 0

        if location == Location.EXTERNAL:
            location_trades = self.data.db.get_trades(
                from_ts=from_ts,
                to_ts=to_ts,
                location=location,
            )
        else:
            # should only be an exchange
            exchange = self.exchange_manager.get(str(location))
            if not exchange:
                logger.warning(
                    f'Tried to query trades from {location} which is either not an '
                    f'exchange or not an exchange the user has connected to',
                )
                return []

            location_trades = exchange.query_trade_history(start_ts=from_ts, end_ts=to_ts)

        trades: List[Trade] = []
        if self.premium is None:
            trades = self._apply_actions_limit(
                location=location,
                action_type='trade',
                location_actions=location_trades,
                all_actions=trades,
            )
        else:
            trades = location_trades

        return trades

    def query_balances(
            self,
            requested_save_data: bool = False,
            timestamp: Optional[Timestamp] = None,
            ignore_cache: bool = False,
    ) -> Dict[str, Any]:
        """Query all balances rotkehlchen can see.

        If requested_save_data is True then the data are always saved in the DB,
        if it is False then data are saved if self.data.should_save_balances()
        is True.
        If timestamp is None then the current timestamp is used.
        If a timestamp is given then that is the time that the balances are going
        to be saved in the DB
        If ignore_cache is True then all underlying calls that have a cache ignore it

        Returns a dictionary with the queried balances.
        """
        log.info('query_balances called', requested_save_data=requested_save_data)

        balances = {}
        problem_free = True
        for _, exchange in self.exchange_manager.connected_exchanges.items():
            exchange_balances, _ = exchange.query_balances(ignore_cache=ignore_cache)
            # If we got an error, disregard that exchange but make sure we don't save data
            if not isinstance(exchange_balances, dict):
                problem_free = False
            else:
                balances[exchange.name] = exchange_balances

        try:
            blockchain_result = self.chain_manager.query_balances(
                blockchain=None,
                force_token_detection=ignore_cache,
                ignore_cache=ignore_cache,
            )
            balances['blockchain'] = {
                asset: balance.to_dict() for asset, balance in blockchain_result.totals.items()
            }
        except (RemoteError, EthSyncError) as e:
            problem_free = False
            log.error(f'Querying blockchain balances failed due to: {str(e)}')

        balances = account_for_manually_tracked_balances(db=self.data.db, balances=balances)

        combined = combine_stat_dicts([v for k, v in balances.items()])
        total_usd_per_location = [(k, dict_get_sumof(v, 'usd_value')) for k, v in balances.items()]

        # calculate net usd value
        net_usd = FVal(0)
        for _, v in combined.items():
            net_usd += FVal(v['usd_value'])

        stats: Dict[str, Any] = {
            'location': {},
            'net_usd': net_usd,
        }
        for entry in total_usd_per_location:
            name = entry[0]
            total = entry[1]
            if net_usd != FVal(0):
                percentage = (total / net_usd).to_percentage()
            else:
                percentage = '0%'
            stats['location'][name] = {
                'usd_value': total,
                'percentage_of_net_value': percentage,
            }

        for k, v in combined.items():
            if net_usd != FVal(0):
                percentage = (v['usd_value'] / net_usd).to_percentage()
            else:
                percentage = '0%'
            combined[k]['percentage_of_net_value'] = percentage

        result_dict = merge_dicts(combined, stats)

        allowed_to_save = requested_save_data or self.data.should_save_balances()

        if problem_free and allowed_to_save:
            if not timestamp:
                timestamp = Timestamp(int(time.time()))
            self.data.save_balances_data(data=result_dict, timestamp=timestamp)
            log.debug('query_balances data saved')
        else:
            log.debug(
                'query_balances data not saved',
                allowed_to_save=allowed_to_save,
                problem_free=problem_free,
            )

        # After adding it to the saved file we can overlay additional data that
        # is not required to be saved in the history file
        try:
            details = self.accountant.events.details
            for asset, (tax_free_amount, average_buy_value) in details.items():
                if asset not in result_dict:
                    continue

                result_dict[asset]['tax_free_amount'] = tax_free_amount
                result_dict[asset]['average_buy_value'] = average_buy_value

                current_price = result_dict[asset]['usd_value'] / result_dict[asset]['amount']
                if average_buy_value != FVal(0):
                    result_dict[asset]['percent_change'] = (
                        ((current_price - average_buy_value) / average_buy_value) * 100
                    )
                else:
                    result_dict[asset]['percent_change'] = 'INF'

        except AttributeError:
            pass

        return result_dict

    def _query_exchange_asset_movements(
            self,
            from_ts: Timestamp,
            to_ts: Timestamp,
            all_movements: List[AssetMovement],
            exchange: ExchangeInterface,
    ) -> List[AssetMovement]:
        """Query one exchange's deposits/withdrawals, applying the free-user limit if needed."""
        location = deserialize_location(exchange.name)
        # clear the asset movements queried for this exchange
        self.actions_per_location['asset_movement'][location] = 0
        location_movements = exchange.query_deposits_withdrawals(
            start_ts=from_ts,
            end_ts=to_ts,
        )

        movements: List[AssetMovement] = []
        if self.premium is None:
            movements = self._apply_actions_limit(
                location=location,
                action_type='asset_movement',
                location_actions=location_movements,
                all_actions=all_movements,
            )
        else:
            movements = location_movements

        return movements

    def query_asset_movements(
            self,
            from_ts: Timestamp,
            to_ts: Timestamp,
            location: Optional[Location],
    ) -> List[AssetMovement]:
        """Queries AssetMovements for the given location and time range.

        If no location is given then all exchange asset movements are queried.
        If the user does not have premium then a limit is applied.
        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        movements: List[AssetMovement] = []
        if location is not None:
            exchange = self.exchange_manager.get(str(location))
            if not exchange:
                logger.warning(
                    f'Tried to query deposits/withdrawals from {location} which is either not an '
                    f'exchange or not an exchange the user has connected to',
                )
                return []
            movements = self._query_exchange_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                all_movements=movements,
                exchange=exchange,
            )
        else:
            # A full query recounts every exchange from scratch so stale counts
            # from a previous call do not eat into the free-user limit.
            self.actions_per_location['asset_movement'].clear()
            for _, exchange in self.exchange_manager.connected_exchanges.items():
                movements = self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=exchange,
                )

        # return movements with most recent first
        movements.sort(key=lambda x: x.timestamp, reverse=True)
        return movements

    def set_settings(self, settings: ModifiableDBSettings) -> Tuple[bool, str]:
        """Tries to set new settings. Returns True in success or False with message if error"""
        with self.lock:
            if settings.eth_rpc_endpoint is not None:
                result, msg = self.chain_manager.set_eth_rpc_endpoint(settings.eth_rpc_endpoint)
                if not result:
                    return False, msg

            if settings.kraken_account_type is not None:
                kraken = self.exchange_manager.get('kraken')
                if kraken:
                    kraken.set_account_type(settings.kraken_account_type)  # type: ignore

            self.data.db.set_settings(settings)
            return True, ''

    def get_settings(self) -> DBSettings:
        """Returns the db settings with a check whether premium is active or not"""
        db_settings = self.data.db.get_settings(have_premium=self.premium is not None)
        return db_settings

    def setup_exchange(
            self,
            name: str,
            api_key: ApiKey,
            api_secret: ApiSecret,
            passphrase: Optional[str] = None,
    ) -> Tuple[bool, str]:
        """
        Setup a new exchange with an api key and an api secret and optionally a passphrase

        By default the api keys are always validated unless validate is False.
        """
        is_success, msg = self.exchange_manager.setup_exchange(
            name=name,
            api_key=api_key,
            api_secret=api_secret,
            database=self.data.db,
            passphrase=passphrase,
        )

        if is_success:
            # Success, save the result in the DB
            self.data.db.add_exchange(name, api_key, api_secret, passphrase=passphrase)
        return is_success, msg

    def remove_exchange(self, name: str) -> Tuple[bool, str]:
        """Unregister exchange *name* and delete its keys and query ranges from the DB."""
        if not self.exchange_manager.has_exchange(name):
            return False, 'Exchange {} is not registered'.format(name)

        self.exchange_manager.delete_exchange(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        self.data.db.delete_used_query_range_for_exchange(name)
        return True, ''

    def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
        """Query for frequently changing data"""
        result: Dict[str, Union[bool, Timestamp]] = {}

        if self.user_is_logged_in:
            result['last_balance_save'] = self.data.db.get_last_balance_save_time()
            result['eth_node_connection'] = self.chain_manager.ethereum.web3_mapping.get(NodeName.OWN, None) is not None  # noqa: E501
            result['history_process_start_ts'] = self.accountant.started_processing_timestamp
            result['history_process_current_ts'] = self.accountant.currently_processing_timestamp
            result['last_data_upload_ts'] = Timestamp(self.premium_sync_manager.last_data_upload_ts)  # noqa: E501
        return result

    def shutdown(self) -> None:
        """Log out the current user and signal the main loop to stop."""
        self.logout()
        self.shutdown_event.set()
class Rotkehlchen(): def __init__(self, args: argparse.Namespace) -> None: """Initialize the Rotkehlchen object May Raise: - SystemPermissionError if the given data directory's permissions are not correct. """ self.lock = Semaphore() self.lock.acquire() # Can also be None after unlock if premium credentials did not # authenticate or premium server temporarily offline self.premium: Optional[Premium] = None self.user_is_logged_in: bool = False configure_logging(args) self.sleep_secs = args.sleep_secs if args.data_dir is None: self.data_dir = default_data_directory() else: self.data_dir = Path(args.data_dir) if not os.access(self.data_dir, os.W_OK | os.R_OK): raise SystemPermissionError( f'The given data directory {self.data_dir} is not readable or writable', ) self.main_loop_spawned = False self.args = args self.api_task_greenlets: List[gevent.Greenlet] = [] self.msg_aggregator = MessagesAggregator() self.greenlet_manager = GreenletManager( msg_aggregator=self.msg_aggregator) self.exchange_manager = ExchangeManager( msg_aggregator=self.msg_aggregator) # Initialize the AssetResolver singleton AssetResolver(data_directory=self.data_dir) self.data = DataHandler(self.data_dir, self.msg_aggregator) self.cryptocompare = Cryptocompare(data_directory=self.data_dir, database=None) self.coingecko = Coingecko(data_directory=self.data_dir) self.icon_manager = IconManager(data_dir=self.data_dir, coingecko=self.coingecko) self.greenlet_manager.spawn_and_track( after_seconds=None, task_name='periodically_query_icons_until_all_cached', exception_is_error=False, method=self.icon_manager.periodically_query_icons_until_all_cached, batch_size=ICONS_BATCH_SIZE, sleep_time_secs=ICONS_QUERY_SLEEP, ) # Initialize the Inquirer singleton Inquirer( data_dir=self.data_dir, cryptocompare=self.cryptocompare, coingecko=self.coingecko, ) # Keeps how many trades we have found per location. 
Used for free user limiting self.actions_per_location: Dict[str, Dict[Location, int]] = { 'trade': defaultdict(int), 'asset_movement': defaultdict(int), } self.lock.release() self.task_manager: Optional[TaskManager] = None self.shutdown_event = gevent.event.Event() def reset_after_failed_account_creation_or_login(self) -> None: """If the account creation or login failed make sure that the Rotki instance is clear Tricky instances are when after either failed premium credentials or user refusal to sync premium databases we relogged in. """ self.cryptocompare.db = None def unlock_user( self, user: str, password: str, create_new: bool, sync_approval: Literal['yes', 'no', 'unknown'], premium_credentials: Optional[PremiumCredentials], initial_settings: Optional[ModifiableDBSettings] = None, ) -> None: """Unlocks an existing user or creates a new one if `create_new` is True May raise: - PremiumAuthenticationError if the password can't unlock the database. - AuthenticationError if premium_credentials are given and are invalid or can't authenticate with the server - DBUpgradeError if the rotki DB version is newer than the software or there is a DB upgrade and there is an error. 
- SystemPermissionError if the directory or DB file can not be accessed """ log.info( 'Unlocking user', user=user, create_new=create_new, sync_approval=sync_approval, initial_settings=initial_settings, ) # unlock or create the DB self.password = password self.user_directory = self.data.unlock(user, password, create_new, initial_settings) self.data_importer = DataImporter(db=self.data.db) self.last_data_upload_ts = self.data.db.get_last_data_upload_ts() self.premium_sync_manager = PremiumSyncManager(data=self.data, password=password) # set the DB in the external services instances that need it self.cryptocompare.set_database(self.data.db) # Anything that was set above here has to be cleaned in case of failure in the next step # by reset_after_failed_account_creation_or_login() try: self.premium = self.premium_sync_manager.try_premium_at_start( given_premium_credentials=premium_credentials, username=user, create_new=create_new, sync_approval=sync_approval, ) except PremiumAuthenticationError: # Reraise it only if this is during the creation of a new account where # the premium credentials were given by the user if create_new: raise self.msg_aggregator.add_warning( 'Could not authenticate the Rotki premium API keys found in the DB.' ' Has your subscription expired?', ) # else let's just continue. 
User signed in succesfully, but he just # has unauthenticable/invalid premium credentials remaining in his DB settings = self.get_settings() self.greenlet_manager.spawn_and_track( after_seconds=None, task_name='submit_usage_analytics', exception_is_error=False, method=maybe_submit_usage_analytics, should_submit=settings.submit_usage_analytics, ) self.etherscan = Etherscan(database=self.data.db, msg_aggregator=self.msg_aggregator) self.beaconchain = BeaconChain(database=self.data.db, msg_aggregator=self.msg_aggregator) eth_rpc_endpoint = settings.eth_rpc_endpoint # Initialize the price historian singleton PriceHistorian( data_directory=self.data_dir, cryptocompare=self.cryptocompare, coingecko=self.coingecko, ) PriceHistorian().set_oracles_order(settings.historical_price_oracles) self.accountant = Accountant( db=self.data.db, user_directory=self.user_directory, msg_aggregator=self.msg_aggregator, create_csv=True, premium=self.premium, ) # Initialize the rotkehlchen logger LoggingSettings(anonymized_logs=settings.anonymized_logs) exchange_credentials = self.data.db.get_exchange_credentials() self.exchange_manager.initialize_exchanges( exchange_credentials=exchange_credentials, database=self.data.db, ) # Initialize blockchain querying modules ethereum_manager = EthereumManager( ethrpc_endpoint=eth_rpc_endpoint, etherscan=self.etherscan, database=self.data.db, msg_aggregator=self.msg_aggregator, greenlet_manager=self.greenlet_manager, connect_at_start=ETHEREUM_NODES_TO_CONNECT_AT_START, ) kusama_manager = SubstrateManager( chain=SubstrateChain.KUSAMA, msg_aggregator=self.msg_aggregator, greenlet_manager=self.greenlet_manager, connect_at_start=KUSAMA_NODES_TO_CONNECT_AT_START, connect_on_startup=self._connect_ksm_manager_on_startup(), own_rpc_endpoint=settings.ksm_rpc_endpoint, ) Inquirer().inject_ethereum(ethereum_manager) Inquirer().set_oracles_order(settings.current_price_oracles) self.chain_manager = ChainManager( 
blockchain_accounts=self.data.db.get_blockchain_accounts(),
        ethereum_manager=ethereum_manager,
        kusama_manager=kusama_manager,
        msg_aggregator=self.msg_aggregator,
        database=self.data.db,
        greenlet_manager=self.greenlet_manager,
        premium=self.premium,
        eth_modules=settings.active_modules,
        data_directory=self.data_dir,
        beaconchain=self.beaconchain,
        btc_derivation_gap_limit=settings.btc_derivation_gap_limit,
    )
    # Per-session helpers; logout() deletes the historian and resets the task manager
    self.events_historian = EventsHistorian(
        user_directory=self.user_directory,
        db=self.data.db,
        msg_aggregator=self.msg_aggregator,
        exchange_manager=self.exchange_manager,
        chain_manager=self.chain_manager,
    )
    self.task_manager = TaskManager(
        max_tasks_num=DEFAULT_MAX_TASKS_NUM,
        greenlet_manager=self.greenlet_manager,
        api_task_greenlets=self.api_task_greenlets,
        database=self.data.db,
        cryptocompare=self.cryptocompare,
        premium_sync_manager=self.premium_sync_manager,
        chain_manager=self.chain_manager,
        exchange_manager=self.exchange_manager,
    )
    self.user_is_logged_in = True
    log.debug('User unlocking complete')

def logout(self) -> None:
    """Tear down the current user's session. Does nothing if nobody is logged in.

    Deactivates premium, clears tracked greenlets, drops the per-user managers
    (chain manager, exchanges, accountant, events historian, data importer) and
    resets logging settings to their defaults. It then logs out of the data
    handler and drains pending errors/warnings.
    """
    if not self.user_is_logged_in:
        return

    user = self.data.username
    log.info(
        'Logging out user',
        user=user,
    )
    # Must run before `del self.chain_manager`: it calls into chain_manager
    self.deactivate_premium_status()
    self.greenlet_manager.clear()
    del self.chain_manager
    self.exchange_manager.delete_all_exchanges()

    # Reset rotkehlchen logger to default
    LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

    del self.accountant
    del self.events_historian
    del self.data_importer

    self.data.logout()
    self.password = ''
    self.cryptocompare.unset_database()

    # Make sure no messages leak to other user sessions
    self.msg_aggregator.consume_errors()
    self.msg_aggregator.consume_warnings()
    # main_loop() only schedules tasks while task_manager is not None
    self.task_manager = None

    self.user_is_logged_in = False
    log.info(
        'User successfully logged out',
        user=user,
    )

def set_premium_credentials(self, credentials: PremiumCredentials) -> None:
    """
    Sets the premium credentials for Rotki.

    Existing premium objects get the new credentials; otherwise a new premium
    object is created and verified. The result is propagated to the premium
    sync manager and the accountant, and the credentials are stored in the DB.

    Raises PremiumAuthenticationError if the given key is rejected by the Rotkehlchen server
    """
    log.info('Setting new premium credentials')
    if self.premium is not None:
        self.premium.set_credentials(credentials)
    else:
        self.premium = premium_create_and_verify(credentials)

    self.premium_sync_manager.premium = self.premium
    self.accountant.premium = self.premium
    self.data.db.set_rotkehlchen_premium(credentials)

def delete_premium_credentials(self) -> Tuple[bool, str]:
    """Deletes the premium credentials for Rotki.

    Premium is deactivated for the current session even if the DB deletion
    failed. Returns (success, message); the message is empty on success.
    """
    msg = ''

    success = self.data.db.del_rotkehlchen_premium()
    if success is False:
        msg = 'The database was unable to delete the Premium keys for the logged-in user'
    self.deactivate_premium_status()
    return success, msg

def deactivate_premium_status(self) -> None:
    """Deactivate premium in the current session.

    Clears the premium object here and in the sync manager, chain manager and accountant.
    """
    self.premium = None
    self.premium_sync_manager.premium = None
    self.chain_manager.deactivate_premium_status()
    self.accountant.deactivate_premium_status()

def start(self) -> gevent.Greenlet:
    """Spawn main_loop in a new greenlet and return that greenlet.

    Asserts that the main loop has not been spawned before.
    """
    assert not self.main_loop_spawned, 'Tried to spawn the main loop twice'
    greenlet = gevent.spawn(self.main_loop)
    self.main_loop_spawned = True
    return greenlet

def main_loop(self) -> None:
    """Rotki main loop that fires often and runs the task manager's scheduler"""
    # Wakes up every MAIN_LOOP_SECS_DELAY seconds. It stops once wait() returns True,
    # presumably when shutdown() sets shutdown_event.
    while self.shutdown_event.wait(
            timeout=MAIN_LOOP_SECS_DELAY) is not True:
        if self.task_manager is not None:
            self.task_manager.schedule()

def get_blockchain_account_data(
        self,
        blockchain: SupportedBlockchain,
) -> Union[List[BlockchainAccountData], Dict[str, Any]]:
    """Return the DB account data for ``blockchain``.

    For every chain except bitcoin this is the plain list from the DB. For
    bitcoin it is a dict with two keys:
    - 'xpubs': one serialized entry per xpub, each with an 'addresses' key
      holding its derived accounts (None if there are none)
    - 'standalone': accounts not derived from any xpub
    """
    account_data = self.data.db.get_blockchain_account_data(blockchain)
    if blockchain != SupportedBlockchain.BITCOIN:
        return account_data

    xpub_data = self.data.db.get_bitcoin_xpub_data()
    addresses_to_account_data = {x.address: x for x in account_data}
    address_to_xpub_mappings = self.data.db.get_addresses_to_xpub_mapping(
        list(addresses_to_account_data.keys()),  # type: ignore
    )

    # Group the account data of derived addresses by their xpub
    xpub_mappings: Dict['XpubData', List[BlockchainAccountData]] = {}
    for address, xpub_entry in address_to_xpub_mappings.items():
        if xpub_entry not in xpub_mappings:
            xpub_mappings[xpub_entry] = []
        xpub_mappings[xpub_entry].append(
            addresses_to_account_data[address])

    data: Dict[str, Any] = {'standalone': [], 'xpubs': []}
    # Add xpub data
    for xpub_entry in xpub_data:
        data_entry = xpub_entry.serialize()
        addresses = xpub_mappings.get(xpub_entry, None)
        # NOTE(review): `len(addresses) != 0` is redundant after the truthiness check
        data_entry['addresses'] = addresses if addresses and len(
            addresses) != 0 else None
        data['xpubs'].append(data_entry)
    # Add standalone addresses
    for account in account_data:
        if account.address not in address_to_xpub_mappings:
            data['standalone'].append(account)

    return data

def add_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
) -> BlockchainBalancesUpdate:
    """Adds new blockchain accounts

    Adds the accounts to the blockchain instance and queries them to get the
    updated balances. Also adds them in the DB

    May raise:
    - EthSyncError from modify_blockchain_account
    - InputError if the given accounts list is empty.
    - TagConstraintError if any of the given account data contain unknown tags.
    - RemoteError if an external service such as Etherscan is queried and
      there is a problem with its query.
    """
    # Validate tags first so that nothing is added when a tag is unknown
    self.data.db.ensure_tags_exist(
        given_data=account_data,
        action='adding',
        data_type='blockchain accounts',
    )
    address_type = blockchain.get_address_type()
    updated_balances = self.chain_manager.add_blockchain_accounts(
        blockchain=blockchain,
        accounts=[address_type(entry.address) for entry in account_data],
    )
    # Persisted only after the chain manager accepted the accounts
    self.data.db.add_blockchain_accounts(
        blockchain=blockchain,
        account_data=account_data,
    )

    return updated_balances

def edit_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
) -> None:
    """Edits blockchain accounts

    Edits blockchain account data for the given accounts

    May raise:
    - InputError if the given accounts list is empty or if
      any of the accounts to edit do not exist.
    - TagConstraintError if any of the given account data contain unknown tags.
    """
    # First check for validity of account data addresses
    if len(account_data) == 0:
        raise InputError(
            'Empty list of blockchain account data to edit was given')
    accounts = [x.address for x in account_data]
    # Only accounts already tracked by the chain manager may be edited
    unknown_accounts = set(accounts).difference(
        self.chain_manager.accounts.get(blockchain))
    if len(unknown_accounts) != 0:
        raise InputError(
            f'Tried to edit unknown {blockchain.value} '
            f'accounts {",".join(unknown_accounts)}',
        )

    self.data.db.ensure_tags_exist(
        given_data=account_data,
        action='editing',
        data_type='blockchain accounts',
    )

    # Finally edit the accounts
    self.data.db.edit_blockchain_accounts(
        blockchain=blockchain,
        account_data=account_data,
    )

def remove_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        accounts: ListOfBlockchainAddresses,
) -> BlockchainBalancesUpdate:
    """Removes blockchain accounts

    Removes the accounts from the blockchain instance and queries them to get
    the updated balances. Also removes them from the DB

    May raise:
    - RemoteError if an external service such as Etherscan is queried and
      there is a problem with its query.
    - InputError if a non-existing account was given to remove
    """
    balances_update = self.chain_manager.remove_blockchain_accounts(
        blockchain=blockchain,
        accounts=accounts,
    )
    self.data.db.remove_blockchain_accounts(blockchain, accounts)
    return balances_update

def get_history_query_status(self) -> Dict[str, str]:
    """Report history processing progress, with both values as strings.

    Progress 0-50 tracks the events historian (its progress halved). Progress
    50-100 tracks how far the accountant got between the processing start and
    the query end timestamp.
    """
    if self.events_historian.progress < FVal('100'):
        processing_state = self.events_historian.processing_state_name
        progress = self.events_historian.progress / 2
    elif self.accountant.first_processed_timestamp == -1:
        # Events retrieved but the accountant has not processed any action yet
        processing_state = 'Processing all retrieved historical events'
        progress = FVal(50)
    else:
        processing_state = 'Processing all retrieved historical events'
        # start_ts is min of the query start or the first action timestamp since action
        # processing can start well before query start to calculate cost basis
        start_ts = min(
            self.accountant.events.query_start_ts,
            self.accountant.first_processed_timestamp,
        )
        diff = self.accountant.events.query_end_ts - start_ts
        # NOTE(review): diff is 0 when query_end_ts == start_ts, which makes this
        # divide by zero; confirm whether that case can occur
        progress = 50 + 100 * (
            FVal(self.accountant.currently_processing_timestamp - start_ts) /
            FVal(diff) / 2)

    return {'processing_state': str(processing_state), 'total_progress': str(progress)}

def process_history(
        self,
        start_ts: Timestamp,
        end_ts: Timestamp,
) -> Tuple[Dict[str, Any], str]:
    """Gather all history events in [start_ts, end_ts] and run them through the accountant.

    Returns the accountant's result and the error string from
    events_historian.get_history (presumably empty on success).
    """
    (
        error_or_empty,
        history,
        loan_history,
        asset_movements,
        eth_transactions,
        defi_events,
        ledger_actions,
    ) = self.events_historian.get_history(
        start_ts=start_ts,
        end_ts=end_ts,
        has_premium=self.premium is not None,
    )
    result = self.accountant.process_history(
        start_ts=start_ts,
        end_ts=end_ts,
        trade_history=history,
        loan_history=loan_history,
        asset_movements=asset_movements,
        eth_transactions=eth_transactions,
        defi_events=defi_events,
        ledger_actions=ledger_actions,
    )
    return result, error_or_empty

# Typing overload: trade lists in, trade list out
@overload
def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade'],
        location_actions: TRADES_LIST,
        all_actions: TRADES_LIST,
) -> TRADES_LIST:
    ...
@overload
def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['asset_movement'],
        location_actions: List[AssetMovement],
        all_actions: List[AssetMovement],
) -> List[AssetMovement]:
    ...

def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade', 'asset_movement'],
        location_actions: Union[TRADES_LIST, List[AssetMovement]],
        all_actions: Union[TRADES_LIST, List[AssetMovement]],
) -> Union[TRADES_LIST, List[AssetMovement]]:
    """Take as many actions from location actions and add them to all actions as the limit permits

    The number taken is stored in ``self.actions_per_location[action_type][location]``.
    ``all_actions`` is extended in place.

    Returns the modified (or not) all_actions
    """
    actions_mapping = self.actions_per_location[action_type]
    current_num_actions = sum(actions_mapping.values())
    limit = LIMITS_MAPPING[action_type]
    # If we are already at or above the limit return current actions disregarding this location
    if current_num_actions >= limit:
        return all_actions

    # Strictly positive thanks to the early return above
    remaining_num_actions = limit - current_num_actions
    num_actions_to_take = min(len(location_actions), remaining_num_actions)

    actions_mapping[location] = num_actions_to_take
    all_actions.extend(location_actions[:num_actions_to_take])  # type: ignore
    return all_actions

def query_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
) -> TRADES_LIST:
    """Queries trades for the given location and time range.
    If no location is given then all external, all exchange and DEX trades are queried.

    DEX Trades are queried only if the user has premium
    If the user does not have premium then a trade limit is applied.

    Returns the trades sorted with the most recent first.

    May raise:
    - RemoteError: If there are problems connecting to any of the remote exchanges
    """
    trades: TRADES_LIST
    if location is not None:
        trades = self.query_location_trades(from_ts, to_ts, location)
    else:
        # Every location is recounted below. Without this reset, counts left over
        # from a previous query would be summed into the free-user limit, starving
        # the locations processed first and leaving stale counts behind.
        self.actions_per_location['trade'].clear()
        trades = self.query_location_trades(from_ts, to_ts, Location.EXTERNAL)
        # crypto.com is not an API key supported exchange but user can import from CSV
        trades.extend(self.query_location_trades(from_ts, to_ts, Location.CRYPTOCOM))
        for name, exchange in self.exchange_manager.connected_exchanges.items():
            exchange_trades = exchange.query_trade_history(start_ts=from_ts, end_ts=to_ts)
            if self.premium is None:
                trades = self._apply_actions_limit(
                    location=deserialize_location(name),
                    action_type='trade',
                    location_actions=exchange_trades,
                    all_actions=trades,
                )
            else:
                trades.extend(exchange_trades)

        # for all trades we also need uniswap trades
        if self.premium is not None:
            uniswap = self.chain_manager.uniswap
            if uniswap is not None:
                trades.extend(
                    uniswap.get_trades(
                        addresses=self.chain_manager.queried_addresses_for_module('uniswap'),
                        from_timestamp=from_ts,
                        to_timestamp=to_ts,
                    ),
                )

    # return trades with most recent first
    trades.sort(key=lambda x: x.timestamp, reverse=True)
    return trades

def query_location_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Location,
) -> TRADES_LIST:
    """Query the trades of a single location in the given time range.

    EXTERNAL and CRYPTOCOM trades come from the DB. UNISWAP trades come from the
    uniswap module, but only for premium users with that module active; otherwise
    an empty list is returned. Any other location is treated as a connected
    exchange; an empty list is returned (with a warning) if it is not one.
    Non premium users get the trade limit applied.
    """
    # clear the trades queried for this location
    self.actions_per_location['trade'][location] = 0

    # Default for the UNISWAP branch when premium or the module is missing;
    # previously this was left unbound and raised UnboundLocalError.
    location_trades: TRADES_LIST = []
    if location in (Location.EXTERNAL, Location.CRYPTOCOM):
        location_trades = self.data.db.get_trades(  # type: ignore  # list invariance
            from_ts=from_ts,
            to_ts=to_ts,
            location=location,
        )
    elif location == Location.UNISWAP:
        if self.premium is not None:
            uniswap = self.chain_manager.uniswap
            if uniswap is not None:
                location_trades = uniswap.get_trades(  # type: ignore  # list invariance
                    addresses=self.chain_manager.queried_addresses_for_module('uniswap'),
                    from_timestamp=from_ts,
                    to_timestamp=to_ts,
                )
    else:
        # should only be an exchange
        exchange = self.exchange_manager.get(str(location))
        if not exchange:
            logger.warning(
                f'Tried to query trades from {location} which is either not an '
                'exchange or not an exchange the user has connected to',
            )
            return []

        location_trades = exchange.query_trade_history(start_ts=from_ts, end_ts=to_ts)

    if self.premium is not None:
        return location_trades

    return self._apply_actions_limit(
        location=location,
        action_type='trade',
        location_actions=location_trades,
        all_actions=[],
    )

def query_balances(
        self,
        requested_save_data: bool = False,
        timestamp: Optional[Timestamp] = None,
        ignore_cache: bool = False,
) -> Dict[str, Any]:
    """Query all balances rotkehlchen can see.

    If requested_save_data is True then the data are always saved in the DB,
    if it is False then data are saved if self.data.should_save_balances()
    is True. Data are never saved if any exchange or the blockchain query failed.

    If timestamp is None then the current timestamp is used.
    If a timestamp is given then that is the time that the balances are going
    to be saved in the DB

    If ignore_cache is True then all underlying calls that have a cache ignore it

    Returns a dictionary with the queried balances.
    """
    log.info('query_balances called', requested_save_data=requested_save_data)

    balances: Dict[str, Dict[Asset, Balance]] = {}
    problem_free = True
    for exchange in self.exchange_manager.connected_exchanges.values():
        exchange_balances, _ = exchange.query_balances(ignore_cache=ignore_cache)
        # If we got an error, disregard that exchange but make sure we don't save data
        if not isinstance(exchange_balances, dict):
            problem_free = False
        else:
            balances[exchange.name] = exchange_balances

    liabilities: Dict[Asset, Balance]
    try:
        blockchain_result = self.chain_manager.query_balances(
            blockchain=None,
            force_token_detection=ignore_cache,
            ignore_cache=ignore_cache,
        )
        balances[str(Location.BLOCKCHAIN)] = blockchain_result.totals.assets
        liabilities = blockchain_result.totals.liabilities
    except (RemoteError, EthSyncError) as e:
        problem_free = False
        liabilities = {}
        log.error(f'Querying blockchain balances failed due to: {str(e)}')

    balances = account_for_manually_tracked_balances(db=self.data.db, balances=balances)

    # Calculate usd totals
    assets_total_balance: DefaultDict[Asset, Balance] = defaultdict(Balance)
    total_usd_per_location: Dict[str, FVal] = {}
    for location, asset_balance in balances.items():
        total_usd_per_location[location] = ZERO
        for asset, balance in asset_balance.items():
            assets_total_balance[asset] += balance
            total_usd_per_location[location] += balance.usd_value

    net_usd = sum((balance.usd_value for balance in assets_total_balance.values()), ZERO)
    liabilities_total_usd = sum(
        (liability.usd_value for liability in liabilities.values()),
        ZERO,
    )
    net_usd -= liabilities_total_usd

    def _share_of_net(usd_value: FVal) -> str:
        """Percentage string of usd_value relative to net_usd ('0%' if net_usd is zero)."""
        return (usd_value / net_usd).to_percentage() if net_usd != ZERO else '0%'

    # Calculate location stats
    location_stats: Dict[str, Any] = {}
    for location, total_usd in total_usd_per_location.items():
        if location == str(Location.BLOCKCHAIN):
            # blockchain liabilities are reported against the blockchain location
            total_usd -= liabilities_total_usd

        location_stats[location] = {
            'usd_value': total_usd,
            'percentage_of_net_value': _share_of_net(total_usd),
        }

    # Calculate 'percentage_of_net_value' per asset
    assets_total_balance_as_dict: Dict[Asset, Dict[str, Any]] = {
        asset: balance.to_dict() for asset, balance in assets_total_balance.items()
    }
    liabilities_as_dict: Dict[Asset, Dict[str, Any]] = {
        asset: balance.to_dict() for asset, balance in liabilities.items()
    }
    for balance_dict in assets_total_balance_as_dict.values():
        balance_dict['percentage_of_net_value'] = _share_of_net(balance_dict['usd_value'])
    for balance_dict in liabilities_as_dict.values():
        balance_dict['percentage_of_net_value'] = _share_of_net(balance_dict['usd_value'])

    # Compose balances response
    result_dict = {
        'assets': assets_total_balance_as_dict,
        'liabilities': liabilities_as_dict,
        'location': location_stats,
        'net_usd': net_usd,
    }
    allowed_to_save = requested_save_data or self.data.should_save_balances()

    if problem_free and allowed_to_save:
        if timestamp is None:
            timestamp = Timestamp(int(time.time()))
        self.data.db.save_balances_data(data=result_dict, timestamp=timestamp)
        log.debug('query_balances data saved')
    else:
        log.debug(
            'query_balances data not saved',
            allowed_to_save=allowed_to_save,
            problem_free=problem_free,
        )

    return result_dict

def _query_exchange_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        all_movements: List[AssetMovement],
        exchange: Union[ExchangeInterface, Location],
) -> List[AssetMovement]:
    """Add one source's deposits/withdrawals to ``all_movements``.

    ``exchange`` is either a connected exchange, or ``Location.CRYPTOCOM``, whose
    movements come from the DB. The location's counter is reset first. Non premium
    users get the asset movement limit applied. ``all_movements`` is extended in
    place and returned.
    """
    if isinstance(exchange, ExchangeInterface):
        location = deserialize_location(exchange.name)
        # clear the asset movements queried for this exchange
        self.actions_per_location['asset_movement'][location] = 0
        location_movements = exchange.query_deposits_withdrawals(
            start_ts=from_ts,
            end_ts=to_ts,
        )
    else:
        assert isinstance(exchange, Location), 'only a location should make it here'
        assert exchange == Location.CRYPTOCOM, 'only cryptocom should make it here'
        location = exchange
        # cryptocom has no exchange integration but we may have DB entries
        self.actions_per_location['asset_movement'][location] = 0
        location_movements = self.data.db.get_asset_movements(
            from_ts=from_ts,
            to_ts=to_ts,
            location=location,
        )

    if self.premium is None:
        return self._apply_actions_limit(
            location=location,
            action_type='asset_movement',
            location_actions=location_movements,
            all_actions=all_movements,
        )

    all_movements.extend(location_movements)
    return all_movements

def query_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
) -> List[AssetMovement]:
    """Queries AssetMovements for the given location and time range.

    If no location is given then all exchange asset movements are queried.
    If the user does not have premium then a limit is applied.
    Returns the movements sorted with the most recent first.

    May raise:
    - RemoteError: If there are problems connecting to any of the remote exchanges
    """
    movements: List[AssetMovement] = []
    if location is not None:
        if location == Location.CRYPTOCOM:
            movements = self._query_exchange_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                all_movements=movements,
                exchange=Location.CRYPTOCOM,
            )
        else:
            exchange = self.exchange_manager.get(str(location))
            if not exchange:
                logger.warning(
                    f'Tried to query deposits/withdrawals from {location} which is either '
                    'not an exchange or not an exchange the user has connected to',
                )
                return []

            movements = self._query_exchange_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                all_movements=movements,
                exchange=exchange,
            )
    else:
        # Every location is recounted below; drop counts from previous queries so
        # they are not summed into the free-user limit.
        self.actions_per_location['asset_movement'].clear()
        # cryptocom has no exchange integration but we may have DB entries due to csv import
        movements = self._query_exchange_asset_movements(
            from_ts=from_ts,
            to_ts=to_ts,
            all_movements=movements,
            exchange=Location.CRYPTOCOM,
        )
        for exchange in self.exchange_manager.connected_exchanges.values():
            movements = self._query_exchange_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                all_movements=movements,
                exchange=exchange,
            )

    # return movements with most recent first
    movements.sort(key=lambda x: x.timestamp, reverse=True)
    return movements

def set_settings(self, settings: ModifiableDBSettings) -> Tuple[bool, str]:
    """Tries to set new settings. Returns True in success or False with message if error

    Runtime components (RPC endpoints, kraken account type, derivation gap limit,
    price oracles) are updated before the settings are written to the DB. A failing
    RPC endpoint change aborts before anything is saved.
    """
    with self.lock:
        if settings.eth_rpc_endpoint is not None:
            result, msg = self.chain_manager.set_eth_rpc_endpoint(settings.eth_rpc_endpoint)
            if not result:
                return False, msg

        if settings.ksm_rpc_endpoint is not None:
            result, msg = self.chain_manager.set_ksm_rpc_endpoint(settings.ksm_rpc_endpoint)
            if not result:
                return False, msg

        if settings.kraken_account_type is not None:
            kraken = self.exchange_manager.get('kraken')
            if kraken:
                kraken.set_account_type(settings.kraken_account_type)  # type: ignore

        if settings.btc_derivation_gap_limit is not None:
            self.chain_manager.btc_derivation_gap_limit = settings.btc_derivation_gap_limit

        if settings.current_price_oracles is not None:
            Inquirer().set_oracles_order(settings.current_price_oracles)

        if settings.historical_price_oracles is not None:
            PriceHistorian().set_oracles_order(settings.historical_price_oracles)

        self.data.db.set_settings(settings)
        return True, ''

def get_settings(self) -> DBSettings:
    """Returns the db settings with a check whether premium is active or not"""
    return self.data.db.get_settings(have_premium=self.premium is not None)

def setup_exchange(
        self,
        name: str,
        api_key: ApiKey,
        api_secret: ApiSecret,
        passphrase: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Setup a new exchange with an api key and an api secret and optionally a passphrase

    The credentials are stored in the DB only if the exchange manager accepted them.
    Returns (success, message) as given by the exchange manager.
    """
    is_success, msg = self.exchange_manager.setup_exchange(
        name=name,
        api_key=api_key,
        api_secret=api_secret,
        database=self.data.db,
        passphrase=passphrase,
    )

    if is_success:
        # Success, save the result in the DB
        self.data.db.add_exchange(name, api_key, api_secret, passphrase=passphrase)
    return is_success, msg

def remove_exchange(self, name: str) -> Tuple[bool, str]:
    """Remove a registered exchange, its DB entry and its used query ranges.

    Returns (False, message) if the exchange is not registered, else (True, '').
    """
    if not self.exchange_manager.has_exchange(name):
        return False, f'Exchange {name} is not registered'

    self.exchange_manager.delete_exchange(name)
    # Success, remove it also from the DB
    self.data.db.remove_exchange(name)
    self.data.db.delete_used_query_range_for_exchange(name)
    return True, ''

def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
    """Query for frequently changing data. Empty when no user is logged in."""
    result: Dict[str, Union[bool, Timestamp]] = {}

    if self.user_is_logged_in:
        result['last_balance_save'] = self.data.db.get_last_balance_save_time()
        result['eth_node_connection'] = self.chain_manager.ethereum.web3_mapping.get(NodeName.OWN, None) is not None  # noqa: E501
        result['last_data_upload_ts'] = Timestamp(self.premium_sync_manager.last_data_upload_ts)  # noqa: E501
    return result

def shutdown(self) -> None:
    """Log out the current user and signal main_loop to stop."""
    self.logout()
    self.shutdown_event.set()

def _connect_ksm_manager_on_startup(self) -> bool:
    """True if the DB has any kusama accounts, so the KSM manager should connect at startup."""
    return bool(self.data.db.get_blockchain_accounts().ksm)

def create_oracle_cache(
        self,
        oracle: HistoricalPriceOracle,
        from_asset: Asset,
        to_asset: Asset,
        purge_old: bool,
) -> None:
    """Creates the cache of the given asset pair from the start of time
    until now for the given oracle.

    if purge_old is true then any old cache in memory and in a file is purged

    Only cryptocompare is supported for now; other oracles are ignored.

    May raise:
    - RemoteError if there is a problem reaching the oracle
    - UnsupportedAsset if any of the two assets is not supported by the oracle
    """
    if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
        return  # only for cryptocompare for now

    self.cryptocompare.create_cache(from_asset, to_asset, purge_old)

def delete_oracle_cache(
        self,
        oracle: HistoricalPriceOracle,
        from_asset: Asset,
        to_asset: Asset,
) -> None:
    """Delete the cache of the given asset pair. Only cryptocompare is supported for now."""
    if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
        return  # only for cryptocompare for now

    self.cryptocompare.delete_cache(from_asset, to_asset)

def get_oracle_cache(self, oracle: HistoricalPriceOracle) -> List[Dict[str, Any]]:
    """Return all cache data of the oracle. Empty for anything but cryptocompare."""
    if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
        return []  # only for cryptocompare for now

    return self.cryptocompare.get_all_cache_data()
class Lock1:
    """Per-greenlet re-entrant lock whose locking logic is currently switched off.

    Every method is a no-op that returns ``None``. The disabled implementation,
    previously kept as commented-out Python 2 code, worked like this:

    * ``Lock`` was re-entrant for the greenlet that already owned the lock (it
      only incremented ``_count``). Otherwise it acquired ``_lock`` with a
      1 second timeout and raised ``Exception("lock error")`` on timeout. It
      then recorded the owner in ``_current`` and registered itself through
      ``AddCurrent``.
    * ``Release`` decremented ``_count``. On reaching zero it released
      ``_lock``, cleared the owner and unregistered through ``SubCurrent``.
    * ``ReleaseEx`` force-released ``_lock`` and reset the owner and count.
    * ``AddCurrent``/``SubCurrent``/``ClearCurrent`` maintained
      ``_gevent_locks`` (greenlet -> set of held locks) under ``_gevent_lock``.
      ``ClearCurrent`` force-released every lock held by the current greenlet.
    * An optional ``tag`` printed "request", "confirm" and "release" trace lines.
    """

    # Class-wide registry of held locks per greenlet (unused while disabled)
    _gevent_locks = {}
    # Guard for _gevent_locks (unused while disabled)
    _gevent_lock = Semaphore()

    @staticmethod
    def AddCurrent(lock):
        """No-op; the disabled version registered ``lock`` for the current greenlet."""

    @staticmethod
    def SubCurrent(lock):
        """No-op; the disabled version unregistered ``lock`` for the current greenlet."""

    @staticmethod
    def ClearCurrent():
        """No-op; the disabled version force-released all locks of the current greenlet."""

    def __init__(self):
        # Underlying semaphore, owning greenlet and re-entrancy depth
        self._lock = Semaphore()
        self._current = None
        self._count = 0

    def Lock(self, tag=None):
        """No-op; acquiring is disabled (``tag`` is accepted and ignored)."""

    def Release(self, tag=None):
        """No-op; releasing is disabled (``tag`` is accepted and ignored)."""

    def ReleaseEx(self):
        """No-op; the disabled version force-released the lock and reset its state."""
import logging try: from gevent.lock import Semaphore except ImportError: from eventlet.semaphore import Semaphore from django.db.backends.mysql.base import DatabaseWrapper as OriginalDatabaseWrapper from .creation import DatabaseCreation from .connection_pool import MysqlConnectionPool logger = logging.getLogger('django.geventpool') connection_pools = {} connection_pools_lock = Semaphore(value=1) DEFAULT_MAX_CONNS = 4 class ConnectionPoolMixin(object): creation_class = DatabaseCreation def __init__(self, settings_dict, *args, **kwargs): def pop_max_conn(settings_dict): if "OPTIONS" in settings_dict: return settings_dict["OPTIONS"].pop("MAX_CONNS", DEFAULT_MAX_CONNS) else: return DEFAULT_MAX_CONNS self._pool = None settings_dict['CONN_MAX_AGE'] = 0
def __init__(self, args: argparse.Namespace) -> None:
    """Build the parts of Rotkehlchen that exist before any user logs in.

    Configures logging and resolves/validates the data directory. Creates the
    message aggregator, greenlet and exchange managers, data handler and price
    sources, plus the AssetResolver and Inquirer singletons. Schedules the
    periodic icon query. Everything up to the per-location action counters runs
    while holding ``self.lock``.

    May Raise:
    - SystemPermissionError if the given data directory's permissions are not correct.
    """
    self.lock = Semaphore()
    self.lock.acquire()

    # Can also be None after unlock if premium credentials did not
    # authenticate or premium server temporarily offline
    self.premium: Optional[Premium] = None
    self.user_is_logged_in: bool = False
    configure_logging(args)

    self.sleep_secs = args.sleep_secs
    self.data_dir = default_data_directory() if args.data_dir is None else Path(args.data_dir)

    readable_and_writable = os.access(self.data_dir, os.W_OK | os.R_OK)
    if not readable_and_writable:
        raise SystemPermissionError(
            f'The given data directory {self.data_dir} is not readable or writable',
        )
    self.args = args

    aggregator = MessagesAggregator()
    self.msg_aggregator = aggregator
    self.greenlet_manager = GreenletManager(msg_aggregator=aggregator)
    self.exchange_manager = ExchangeManager(msg_aggregator=aggregator)

    # Initialize the AssetResolver singleton
    AssetResolver(data_directory=self.data_dir)
    self.data = DataHandler(self.data_dir, aggregator)

    # Price sources; cryptocompare gets its database once a user logs in
    self.cryptocompare = Cryptocompare(data_directory=self.data_dir, database=None)
    self.coingecko = Coingecko()
    self.icon_manager = IconManager(data_dir=self.data_dir, coingecko=self.coingecko)
    self.greenlet_manager.spawn_and_track(
        after_seconds=None,
        task_name='periodically_query_icons_until_all_cached',
        method=self.icon_manager.periodically_query_icons_until_all_cached,
        batch_size=ICONS_BATCH_SIZE,
        sleep_time_secs=ICONS_QUERY_SLEEP,
    )

    # Initialize the Inquirer singleton
    Inquirer(
        data_dir=self.data_dir,
        cryptocompare=self.cryptocompare,
        coingecko=self.coingecko,
    )

    # Keeps how many trades we have found per location. Used for free user limiting
    self.actions_per_location: Dict[str, Dict[Location, int]] = {
        action_kind: defaultdict(int) for action_kind in ('trade', 'asset_movement')
    }

    self.lock.release()
    # Waited on by the main loop; setting it stops the loop
    self.shutdown_event = gevent.event.Event()
def test_acquire_returns_false_after_timeout(self):
    """A semaphore with no free slots must report failure once the timeout expires."""
    acquired = Semaphore(value=0).acquire(timeout=0.01)
    assert acquired is False, repr(acquired)
def __init__(
        self,
        web3: Web3,
        privkey: bytes,
        gas_price_strategy: Callable = rpc_gas_price_strategy,
        gas_estimate_correction: Callable = lambda gas: gas,
        block_num_confirmations: int = 0,
        uses_infura: bool = False,
):
    """Create a JSON-RPC client for ``web3`` acting as the account of ``privkey``.

    Args:
        web3: Web3 instance; it is monkey-patched with ``gas_price_strategy``.
        privkey: 32-byte private key of the sending account.
        gas_price_strategy: Strategy handed to ``monkey_patch_web3``.
        gas_estimate_correction: Stored as ``self._gas_estimate_correction``;
            identity by default.
        block_num_confirmations: Default number of block confirmations (>= 0).
        uses_infura: When True, the next nonce is the account's 'pending'
            transaction count and a warning about nonce recovery is emitted.
            Otherwise the nonce is discovered with Parity- or Geth-specific RPCs.

    Raises:
        ValueError: ``privkey`` is None or not 32 bytes long, or
            ``block_num_confirmations`` is negative.
        EthNodeCommunicationError: Connecting to the node timed out.
        EthNodeInterfaceError: The node is neither Parity nor Geth
            (only checked when ``uses_infura`` is False).
    """
    if privkey is None or len(privkey) != 32:
        raise ValueError('Invalid private key')

    # 0 confirmations is allowed, so the requirement is "non-negative"
    if block_num_confirmations < 0:
        raise ValueError('Number of confirmations has to be non-negative')

    monkey_patch_web3(web3, gas_price_strategy)

    try:
        version = web3.version.node
    except ConnectTimeout as err:
        # Keep the timeout as the explicit cause for debugging
        raise EthNodeCommunicationError('couldnt reach the ethereum node') from err

    _, eth_node = is_supported_client(version)

    address = privatekey_to_address(privkey)
    address_checksumed = to_checksum_address(address)

    if uses_infura:
        warnings.warn(
            'Infura does not provide an API to '
            'recover the latest used nonce. This may cause the Raiden node '
            'to error on restarts.\n'
            'The error will manifest while there is a pending transaction '
            'from a previous execution in the Ethereum\'s client pool. When '
            'Raiden restarts the same transaction with the same nonce will '
            'be retried and *rejected*, because the nonce is already used.',
        )
        # The first valid nonce is 0, therefore the count is already the next
        # available nonce
        available_nonce = web3.eth.getTransactionCount(address_checksumed, 'pending')

    elif eth_node == constants.EthClient.PARITY:
        parity_assert_rpc_interfaces(web3)
        available_nonce = parity_discover_next_available_nonce(
            web3,
            address_checksumed,
        )

    elif eth_node == constants.EthClient.GETH:
        geth_assert_rpc_interfaces(web3)
        available_nonce = geth_discover_next_available_nonce(
            web3,
            address_checksumed,
        )

    else:
        raise EthNodeInterfaceError(f'Unsupported Ethereum client {version}')

    self.eth_node = eth_node
    self.privkey = privkey
    self.address = address
    self.web3 = web3
    self.default_block_num_confirmations = block_num_confirmations

    self._available_nonce = available_nonce
    # Guards _available_nonce (used by code outside this view)
    self._nonce_lock = Semaphore()
    self._gas_estimate_correction = gas_estimate_correction

    log.debug(
        'JSONRPCClient created',
        node=pex(self.address),
        available_nonce=available_nonce,
        client=version,
    )
def __init__(
        self,
        chain: BlockChainService,
        query_start_block: BlockNumber,
        default_registry: TokenNetworkRegistry,
        default_secret_registry: SecretRegistry,
        transport,
        raiden_event_handler,
        message_handler,
        config,
        discovery=None,
):
    """Wire the Raiden service together; it starts in the stopped state.

    ``config`` must provide at least 'contracts_path' and 'database_path'. A
    'database_path' of ':memory:' means there is no database directory and no
    file lock. Any other path gets its parent directory created if needed, and
    a ``.lock`` file lock is placed next to it.
    """
    super().__init__()
    self.tokennetworkids_to_connectionmanagers = {}
    self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

    self.chain: BlockChainService = chain
    self.default_registry = default_registry
    self.query_start_block = query_start_block
    self.default_secret_registry = default_secret_registry
    self.config = config

    self.signer: Signer = LocalSigner(self.chain.client.privkey)
    self.address = self.signer.address
    self.discovery = discovery
    self.transport = transport

    self.blockchain_events = BlockchainEvents()
    self.alarm = AlarmTask(chain)
    self.raiden_event_handler = raiden_event_handler
    self.message_handler = message_handler

    self.stop_event = Event()
    self.stop_event.set()  # inits as stopped

    self.wal = None
    self.snapshot_group = 0

    # This flag will be used to prevent the service from processing
    # state changes events until we know that pending transactions
    # have been dispatched.
    self.dispatch_events_lock = Semaphore(1)

    self.contract_manager = ContractManager(config['contracts_path'])
    self.database_path = config['database_path']
    if self.database_path != ':memory:':
        database_dir = os.path.dirname(config['database_path'])
        # A bare file name has an empty dirname (the current directory), and
        # os.makedirs('') raises FileNotFoundError even with exist_ok=True.
        if database_dir:
            os.makedirs(database_dir, exist_ok=True)

        self.database_dir = database_dir
        # Two raiden processes must not write to the same database, even
        # though the database itself may be consistent. If more than one
        # nodes writes state changes to the same WAL there are no
        # guarantees about recovery, this happens because during recovery
        # the WAL replay can not be deterministic.
        self.lock_file = os.path.join(self.database_dir, '.lock')
        self.db_lock = filelock.FileLock(self.lock_file)
    else:
        self.database_path = ':memory:'
        self.database_dir = None
        self.lock_file = None
        # NOTE(review): serialization_file is only assigned on this branch
        self.serialization_file = None
        self.db_lock = None

    self.event_poll_lock = gevent.lock.Semaphore()
    self.gas_reserve_lock = gevent.lock.Semaphore()
    self.payment_identifier_lock = gevent.lock.Semaphore()
class ThreadPool(object):
    """Pool of OS worker threads fed from ``task_queue`` and driven from gevent.

    ``maxsize`` caps the number of worker threads, and ``size`` is the current
    thread count. Work is submitted with ``spawn()``, which returns an
    ``AsyncResult``; the ``apply*``/``map*``/``imap*`` helpers are built on it.
    """

    def __init__(self, maxsize, hub=None):
        """Create the pool on ``hub`` (default: ``get_hub()``) with up to ``maxsize`` threads."""
        if hub is None:
            hub = get_hub()
        self.hub = hub
        # Starts at 0 so the first _set_maxsize() call (via _init) applies the full delta.
        self._maxsize = 0
        # Greenlet running _adjust_wait while the pool is above maxsize, else None.
        self.manager = None
        self.pid = os.getpid()
        # Watcher started/stopped in _adjust_step/_set_size; calls _on_fork.
        self.fork_watcher = hub.loop.fork(ref=False)
        self._init(maxsize)

    def _set_maxsize(self, maxsize):
        """Validate and apply a new maximum thread count.

        Raises TypeError if ``maxsize`` is not an integer and ValueError if
        it is negative. The spawn semaphore counter is shifted by the
        difference between the new and old maxsize.
        """
        if not isinstance(maxsize, integer_types):
            raise TypeError('maxsize must be integer: %r' % (maxsize, ))
        if maxsize < 0:
            raise ValueError('maxsize must not be negative: %r' % (maxsize, ))
        difference = maxsize - self._maxsize
        # NOTE(review): _init creates Semaphore(1) and then calls this, so on
        # first use the counter becomes 1 + maxsize, and after _on_fork (delta 0)
        # it is reset to 1. Confirm this is the intended number of concurrent
        # spawn() callers.
        self._semaphore.counter += difference
        self._maxsize = maxsize
        self.adjust()
        # make sure all currently blocking spawn() start unlocking if maxsize increased
        self._semaphore._start_notify()

    def _get_maxsize(self):
        return self._maxsize

    maxsize = property(_get_maxsize, _set_maxsize)

    def __repr__(self):
        return '<%s at 0x%x %s/%s/%s>' % (self.__class__.__name__, id(self), len(self), self.size, self.maxsize)

    def __len__(self):
        """Number of tasks queued or running (``task_queue.unfinished_tasks``)."""
        # XXX just do unfinished_tasks property
        return self.task_queue.unfinished_tasks

    def _get_size(self):
        return self._size

    def _set_size(self, size):
        """Grow or shrink the pool to ``size`` worker threads.

        Raises ValueError if ``size`` is negative or larger than ``maxsize``.
        Shrinking works by enqueueing ``None`` sentinels and polling with
        exponential backoff (capped at 0.05s) until enough workers have exited.
        When called from the hub itself it enqueues the sentinels once and
        stops waiting.
        """
        if size < 0:
            raise ValueError('Size of the pool cannot be negative: %r' % (size, ))
        if size > self._maxsize:
            raise ValueError(
                'Size of the pool cannot be bigger than maxsize: %r > %r' % (size, self._maxsize))
        if self.manager:
            self.manager.kill()
        while self._size < size:
            self._add_thread()
        delay = 0.0001
        while self._size > size:
            # Only add sentinels for workers that are not already busy with a task.
            while self._size - size > self.task_queue.unfinished_tasks:
                self.task_queue.put(None)
            # The hub cannot sleep() (it would block itself), so stop waiting here.
            if getcurrent() is self.hub:
                break
            sleep(delay)
            delay = min(delay * 2, .05)
        # Only watch for fork() while there are threads to lose.
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    size = property(_get_size, _set_size)

    def _init(self, maxsize):
        """(Re)create the thread count, semaphore, lock and task queue, then apply ``maxsize``."""
        self._size = 0
        self._semaphore = Semaphore(1)
        self._lock = Lock()
        self.task_queue = Queue()
        self._set_maxsize(maxsize)

    def _on_fork(self):
        """Fork-watcher callback: in a child process, rebuild the pool state from scratch."""
        # fork() only leaves one thread; also screws up locks;
        # let's re-create locks and threads
        pid = os.getpid()
        if pid != self.pid:
            self.pid = pid
            # Do not mix fork() and threads; since fork() only copies one thread
            # all objects referenced by other threads has refcount that will never
            # go down to 0.
            self._init(self._maxsize)

    def join(self):
        """Block (polling with backoff capped at 0.05s) until no tasks remain unfinished."""
        delay = 0.0005
        while self.task_queue.unfinished_tasks > 0:
            sleep(delay)
            delay = min(delay * 2, .05)

    def kill(self):
        """Shrink the pool to zero worker threads (via the ``size`` setter)."""
        self.size = 0

    def _adjust_step(self):
        """Add threads for pending work, and queue stop sentinels for threads above maxsize."""
        # if there is a possibility & necessity for adding a thread, do it
        while self._size < self._maxsize and self.task_queue.unfinished_tasks > self._size:
            self._add_thread()
        # while the number of threads is more than maxsize, kill one
        # we do not check what's already in task_queue - it could be all Nones
        while self._size - self._maxsize > self.task_queue.unfinished_tasks:
            self.task_queue.put(None)
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    def _adjust_wait(self):
        """Repeat _adjust_step with backoff until the thread count is within maxsize."""
        delay = 0.0001
        while True:
            self._adjust_step()
            if self._size <= self._maxsize:
                return
            sleep(delay)
            delay = min(delay * 2, .05)

    def adjust(self):
        """Run one adjustment step; if still above maxsize, start a manager greenlet to finish."""
        self._adjust_step()
        # NOTE(review): self.manager is never reset to None here; this relies on
        # a finished Greenlet being falsy so a new manager can be started later.
        if not self.manager and self._size > self._maxsize:
            # might need to feed more Nones into the pool
            self.manager = Greenlet.spawn(self._adjust_wait)

    def _add_thread(self):
        """Start one worker thread, counting it first and undoing the count if start fails."""
        with self._lock:
            self._size += 1
        try:
            start_new_thread(self._worker, ())
        except:
            # Bare except on purpose: roll back the count for any failure, then re-raise.
            with self._lock:
                self._size -= 1
            raise

    def spawn(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` for a worker thread and return its AsyncResult.

        Blocks on the pool semaphore first. The semaphore is released when the
        result becomes ready (via ``rawlink``), or immediately if queueing fails.
        """
        while True:
            semaphore = self._semaphore
            semaphore.acquire()
            # _init (e.g. from _on_fork) may have replaced the semaphore while we
            # were blocked; only proceed if we hold the current one.
            if semaphore is self._semaphore:
                break
        try:
            task_queue = self.task_queue
            result = AsyncResult()
            thread_result = ThreadResult(result, hub=self.hub)
            task_queue.put((func, args, kwargs, thread_result))
            self.adjust()
            # rawlink() must be the last call
            result.rawlink(lambda *args: self._semaphore.release())
            # XXX this _semaphore.release() is competing for order with get()
            # XXX this is not good, just make ThreadResult release the semaphore before doing anything else
        except:
            semaphore.release()
            raise
        return result

    def _decrease_size(self):
        """Decrement the thread count under ``_lock``, if the lock still exists."""
        # Presumably guards against module globals (sys) being cleared during
        # interpreter shutdown — NOTE(review): confirm.
        if sys is None:
            return
        _lock = getattr(self, '_lock', None)
        if _lock is not None:
            with _lock:
                self._size -= 1

    def _worker(self):
        """Body of every worker thread.

        Takes tasks from ``task_queue`` until a ``None`` sentinel arrives. Each
        task's return value or exception is delivered to its ThreadResult.
        ``task_done()`` is always called on the queue.
        """
        need_decrease = True
        try:
            while True:
                task_queue = self.task_queue
                task = task_queue.get()
                try:
                    if task is None:
                        need_decrease = False
                        self._decrease_size()
                        # we want first to decrease size, then decrease unfinished_tasks
                        # otherwise, _adjust might think there's one more idle thread that
                        # needs to be killed
                        return
                    func, args, kwargs, result = task
                    try:
                        value = func(*args, **kwargs)
                    except:
                        # Forward any exception to the waiting greenlet.
                        exc_info = getattr(sys, 'exc_info', None)
                        if exc_info is None:
                            return
                        result.handle_error((self, func), exc_info())
                    else:
                        if sys is None:
                            return
                        result.set(value)
                        del value
                    finally:
                        # Drop references so the task's objects are not kept alive by this thread.
                        del func, args, kwargs, result, task
                finally:
                    if sys is None:
                        return
                    task_queue.task_done()
        finally:
            # Exited without consuming a sentinel (e.g. an early return): still uncount this thread.
            if need_decrease:
                self._decrease_size()

    # XXX apply() should re-raise error by default
    # XXX because that's what builtin apply does
    # XXX check gevent.pool.Pool.apply and multiprocessing.Pool.apply
    def apply_e(self, expected_errors, function, args=None, kwargs=None):
        """Run ``function`` in a thread; return its result, or raise the caught expected error.

        Exceptions matching ``expected_errors`` are returned from the thread by
        ``wrap_errors`` and re-raised here.
        """
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        success, result = self.spawn(wrap_errors, expected_errors, function, args, kwargs).get()
        if success:
            return result
        raise result

    def apply(self, func, args=None, kwds=None):
        """Equivalent of the apply() builtin function. It blocks till the result is ready."""
        if args is None:
            args = ()
        if kwds is None:
            kwds = {}
        return self.spawn(func, *args, **kwds).get()

    def apply_cb(self, func, args=None, kwds=None, callback=None):
        """Blocking apply(); if ``callback`` is given, call it with the result before returning it."""
        result = self.apply(func, args, kwds)
        if callback is not None:
            callback(result)
        return result

    def apply_async(self, func, args=None, kwds=None, callback=None):
        """A variant of the apply() method which returns a Greenlet object.

        If callback is specified then it should be a callable which accepts a
        single argument. When the result becomes ready callback is applied to
        it (unless the call failed)."""
        if args is None:
            args = ()
        if kwds is None:
            kwds = {}
        return Greenlet.spawn(self.apply_cb, func, args, kwds, callback)

    def map(self, func, iterable):
        """Blocking map: list of ``func(item)`` for each item, computed in worker threads."""
        return list(self.imap(func, iterable))

    def map_cb(self, func, iterable, callback=None):
        """Blocking map(); if ``callback`` is given, call it with the result list before returning it."""
        result = self.map(func, iterable)
        if callback is not None:
            callback(result)
        return result

    def map_async(self, func, iterable, callback=None):
        """
        A variant of the map() method which returns a Greenlet object.

        If callback is specified then it should be a callable which accepts a
        single argument.
        """
        return Greenlet.spawn(self.map_cb, func, iterable, callback)

    def imap(self, func, iterable):
        """An equivalent of itertools.imap()"""
        return IMap.spawn(func, iterable, spawn=self.spawn)

    def imap_unordered(self, func, iterable):
        """The same as imap() except that the ordering of the results from the
        returned iterator should be considered in arbitrary order."""
        return IMapUnordered.spawn(func, iterable, spawn=self.spawn)
def __init__(self, conn):
    """Wrap ``conn``; the wrapper starts open with no recorded connection error.

    Also creates a reentrant lock (``conn_mutex``) and a semaphore with an
    initial count of 0 (``conn_cond``), so its first acquire waits for a release.
    """
    self.conn = conn
    self.connerr = None
    self.closed = False
    self.conn_cond = Semaphore(0)
    self.conn_mutex = RLock()
class RpcConn(object):
    """One RPC connection over a socket.

    Reads length-prefixed ``WirePayload`` frames from ``sck``, dispatches
    requests to the services registered on ``server`` (each call runs in its
    own greenlet) and honours cancel messages. Replies are written back with
    the same int32 length prefix.
    """

    def __init__(self, sck, server):
        self._sck = sck
        self._rlock = Semaphore()  # serialises frame reads
        self._wlock = Semaphore()  # serialises frame writes so replies never interleave
        self._server = server
        # call_id -> RpcController for calls that have not finished yet.
        self._living_controllers = {}

    def on_reqeust(self, request):
        """Start handling one RPC request (the misspelled name is kept for existing callers).

        Requests for an unknown service or method, or whose body does not
        parse into an initialized message, are dropped without a reply.
        """
        service = self.get_service(request.service_identifier)
        if service is None:
            return
        method = self.get_service_method(service, request.method_identifier)
        if method is None:
            return
        proto_request = self.get_proto_request(service, method, request)
        if proto_request is None:
            # get_proto_request signals an incomplete/invalid message with None;
            # do not hand None to the service implementation.
            return

        req_id = request.call_id
        controller = RpcController()
        self._living_controllers[req_id] = controller
        # gevent.spawn() already schedules the greenlet; no extra start() is needed.
        gevent.spawn(self.call_method_and_reply, req_id, service, method, controller, proto_request)

    def call_method_and_reply(self, req_id, service, method, controller, proto_request):
        """Invoke the service method and send back a response or error frame.

        Nothing is sent if the call was cancelled. The controller is always
        removed from ``_living_controllers`` when this returns, even if the
        service or the socket write raises.
        """
        try:
            callback = Callback()
            service.CallMethod(method, controller, proto_request, callback)
            if controller.IsCanceled():
                return

            payload = WirePayload()
            if callback.response is not None and not controller.Failed():
                payload.rpc_response.call_id = req_id
                payload.rpc_response.response_bytes = callback.response.SerializeToString()
            elif controller.Failed():
                payload.rpc_error.call_id = req_id
                payload.rpc_error.error = controller.ErrorText()

            buf = payload.SerializeToString()
            # ``with`` guarantees the write lock is released if sendall() raises.
            with self._wlock:
                self._sck.sendall(utils.int32_to_bytes(len(buf)) + buf)
        finally:
            # finish this call
            self._living_controllers.pop(req_id, None)

    def on_cancel(self, rpc_cancel):
        """Ask the controller of a still-running call to cancel, if it exists."""
        controller = self._living_controllers.get(rpc_cancel.call_id_to_cancel)
        if controller is not None:
            controller.StartCancel()

    def get_service(self, service_name):
        """Return the service registered under ``service_name`` on the server, or None."""
        service = self._server._serivces.get(service_name)
        return service

    def get_service_method(self, service, method_name):
        """Return the method descriptor named ``method_name``, or None if it does not exist."""
        try:
            return service.DESCRIPTOR.FindMethodByName(method_name)
        except KeyError:
            # Newer protobuf releases raise KeyError instead of returning None.
            return None

    def get_proto_request(self, service, method, request):
        """Decode ``request.request_bytes`` into the method's request message.

        Returns None if the parsed message is not fully initialized.
        """
        proto_request = service.GetRequestClass(method)()
        proto_request.ParseFromString(request.request_bytes)
        # Check the request parsed correctly
        if not proto_request.IsInitialized():
            return None
        return proto_request

    def run(self):
        """Read and dispatch frames until the peer closes the connection."""
        while True:
            # read payload package; ``with`` releases the read lock on every exit path
            with self._rlock:
                sz = utils.read_int32(self._sck)
                if sz is None:
                    break
                buf = utils.readall(self._sck, sz)
                if buf is None:
                    break

            payload = WirePayload()
            payload.ParseFromString(buf)
            # if has rpc request
            if payload.rpc_cancel.IsInitialized():
                self.on_cancel(payload.rpc_cancel)
            elif payload.rpc_request.IsInitialized():
                self.on_reqeust(payload.rpc_request)
draw_all() def exposed_set_battery_percent(self, pct): """ `pct` should be an integer in [0, 100]. """ with self.lock: if not isinstance(pct, int) or not (0 <= pct <= 100): raise Exception("Invalid battery percent") pct = "{}%".format(pct) global battery_sprite battery_sprite = header_font.render(pct, True, HEADER_TXT_COLOR) draw_all() from rpyc.utils.server import GeventServer from rpyc.utils.helpers import classpartial global_lock = Semaphore(value=1) ConsoleService = classpartial(ConsoleService, global_lock) rpc_server = GeventServer(ConsoleService, port=18863) log.info("RUNNING!") gevent.joinall([ gevent.spawn(rpc_server.start), ])
class Rotkehlchen:
    """Application object holding global setup and the logged-in user's services.

    ``__init__`` configures logging and creates the user-independent objects
    (data handler, exchange manager, message aggregator). ``unlock_user``
    creates the per-user objects (accountant, blockchain, historians, ...).
    ``logout`` tears them down again.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        """Configure logging from ``args`` and create the user-independent services.

        Raises:
            ValueError: if ``args.loglevel`` is not one of
                debug/info/warn/error/critical.
        """
        self.lock = Semaphore()
        # ``with`` releases the lock even if setup raises part-way through.
        with self.lock:
            self.premium = None
            self.user_is_logged_in = False

            logfilename = None
            if args.logtarget == 'file':
                logfilename = args.logfile

            loglevels = {
                'debug': logging.DEBUG,
                'info': logging.INFO,
                'warn': logging.WARN,
                'error': logging.ERROR,
                'critical': logging.CRITICAL,
            }
            loglevel = loglevels.get(args.loglevel)
            if loglevel is None:
                raise ValueError('Should never get here. Illegal log value')

            logging.basicConfig(
                filename=logfilename,
                filemode='w',
                level=loglevel,
                format='%(asctime)s -- %(levelname)s:%(name)s:%(message)s',
                datefmt='%d/%m/%Y %H:%M:%S %Z',
            )

            if not args.logfromothermodules:
                logging.getLogger('zerorpc').setLevel(logging.CRITICAL)
                logging.getLogger('zerorpc.channel').setLevel(logging.CRITICAL)
                logging.getLogger('urllib3').setLevel(logging.CRITICAL)
                logging.getLogger('urllib3.connectionpool').setLevel(logging.CRITICAL)

            self.sleep_secs = args.sleep_secs
            self.data_dir = args.data_dir
            self.args = args
            self.msg_aggregator = MessagesAggregator()
            self.exchange_manager = ExchangeManager(msg_aggregator=self.msg_aggregator)
            self.data = DataHandler(self.data_dir, self.msg_aggregator)
            # Initialize the Inquirer singleton
            Inquirer(data_dir=self.data_dir)

        self.shutdown_event = gevent.event.Event()

    def unlock_user(
            self,
            user: str,
            password: str,
            create_new: bool,
            sync_approval: str,
            premium_credentials: Optional[PremiumCredentials],
    ) -> None:
        """Unlocks an existing user or creates a new one if `create_new` is True"""
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
        )
        # unlock or create the DB
        self.password = password
        self.user_directory = self.data.unlock(user, password, create_new)
        self.data_importer = DataImporter(db=self.data.db)
        self.last_data_upload_ts = self.data.db.get_last_data_upload_ts()
        self.premium_sync_manager = PremiumSyncManager(data=self.data, password=password)

        try:
            self.premium = self.premium_sync_manager.try_premium_at_start(
                given_premium_credentials=premium_credentials,
                username=user,
                create_new=create_new,
                sync_approval=sync_approval,
            )
        except AuthenticationError:
            # It means that our credentials were not accepted by the server
            # or some other error happened
            pass

        settings = self.data.db.get_settings()
        maybe_submit_usage_analytics(settings.submit_usage_analytics)

        historical_data_start = settings.historical_data_start
        eth_rpc_endpoint = settings.eth_rpc_endpoint
        self.trades_historian = TradesHistorian(
            user_directory=self.user_directory,
            db=self.data.db,
            eth_accounts=self.data.get_eth_accounts(),
            msg_aggregator=self.msg_aggregator,
            exchange_manager=self.exchange_manager,
        )
        # Initialize the price historian singleton
        PriceHistorian(
            data_directory=self.data_dir,
            history_date_start=historical_data_start,
            cryptocompare=Cryptocompare(data_directory=self.data_dir),
        )

        db_settings = self.data.db.get_settings()
        self.accountant = Accountant(
            profit_currency=self.data.main_currency(),
            user_directory=self.user_directory,
            msg_aggregator=self.msg_aggregator,
            create_csv=True,
            ignored_assets=self.data.db.get_ignored_assets(),
            include_crypto2crypto=db_settings.include_crypto2crypto,
            taxfree_after_period=db_settings.taxfree_after_period,
            include_gas_costs=db_settings.include_gas_costs,
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=db_settings.anonymized_logs)
        exchange_credentials = self.data.db.get_exchange_credentials()
        self.exchange_manager.initialize_exchanges(
            exchange_credentials=exchange_credentials,
            database=self.data.db,
        )

        ethchain = Ethchain(eth_rpc_endpoint)
        self.blockchain = Blockchain(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            owned_eth_tokens=self.data.db.get_owned_tokens(),
            ethchain=ethchain,
            msg_aggregator=self.msg_aggregator,
        )
        self.user_is_logged_in = True

    def logout(self) -> None:
        """Tear down the per-user objects; does nothing if no user is logged in."""
        if not self.user_is_logged_in:
            return

        user = self.data.username
        log.info(
            'Logging out user',
            user=user,
        )
        del self.blockchain
        self.exchange_manager.delete_all_exchanges()

        # Reset rotkehlchen logger to default
        LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

        del self.accountant
        del self.trades_historian
        del self.data_importer

        # Reset instead of ``del``: unlock_user keeps the previous value when
        # premium authentication fails, and set_premium_credentials/logout read
        # ``self.premium``, so the attribute must always exist.
        self.premium = None
        self.data.logout()
        self.password = ''

        self.user_is_logged_in = False
        log.info(
            'User successfully logged out',
            user=user,
        )

    def set_premium_credentials(self, credentials: PremiumCredentials) -> None:
        """
        Sets the premium credentials for Rotki

        Raises AuthenticationError if the given key is rejected by the Rotkehlchen server
        """
        log.info('Setting new premium credentials')

        if self.premium is not None:
            # For some reason mypy does not see that self.premium is set
            self.premium.set_credentials(credentials)  # type: ignore
        else:
            self.premium = premium_create_and_verify(credentials)

        self.data.db.set_rotkehlchen_premium(credentials)

    def start(self) -> gevent.Greenlet:
        """Spawn and return the greenlet running ``main_loop``."""
        return gevent.spawn(self.main_loop)

    def main_loop(self) -> None:
        """Every MAIN_LOOP_SECS_DELAY seconds, until shutdown, upload data if a user is logged in."""
        while self.shutdown_event.wait(MAIN_LOOP_SECS_DELAY) is not True:
            if self.user_is_logged_in:
                log.debug('Main loop start')
                self.premium_sync_manager.maybe_upload_data_to_server()
                log.debug('Main loop end')

    def add_blockchain_account(
            self,
            blockchain: SupportedBlockchain,
            account: BlockchainAddress,
    ) -> Dict:
        """Track a new account; returns an accounts result or a failure result with the error."""
        try:
            new_data = self.blockchain.add_blockchain_account(blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.add_blockchain_account(blockchain, account)

        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_blockchain_account(
            self,
            blockchain: SupportedBlockchain,
            account: BlockchainAddress,
    ) -> Dict[str, Any]:
        """Stop tracking an account; returns an accounts result or a failure result."""
        try:
            new_data = self.blockchain.remove_blockchain_account(blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.remove_blockchain_account(blockchain, account)

        return accounts_result(new_data['per_account'], new_data['totals'])

    def add_owned_eth_tokens(self, tokens: List[str]) -> Dict[str, Any]:
        """Track the given token identifiers and persist the updated token list."""
        ethereum_tokens = [
            EthereumToken(identifier=identifier) for identifier in tokens
        ]
        try:
            new_data = self.blockchain.track_new_tokens(ethereum_tokens)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))

        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_owned_eth_tokens(self, tokens: List[str]) -> Dict[str, Any]:
        """Stop tracking the given token identifiers and persist the updated token list."""
        ethereum_tokens = [
            EthereumToken(identifier=identifier) for identifier in tokens
        ]
        try:
            new_data = self.blockchain.remove_eth_tokens(ethereum_tokens)
        except InputError as e:
            return simple_result(False, str(e))
        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def process_history(
            self,
            start_ts: Timestamp,
            end_ts: Timestamp,
    ) -> Tuple[Dict[str, Any], str]:
        """Run the accountant over ``[start_ts, end_ts]``; returns (result, error-or-empty string)."""
        (
            error_or_empty,
            history,
            loan_history,
            asset_movements,
            eth_transactions,
        ) = self.trades_historian.get_history(
            # For entire history processing we need to have full history available
            start_ts=Timestamp(0),
            end_ts=ts_now(),
        )
        result = self.accountant.process_history(
            start_ts=start_ts,
            end_ts=end_ts,
            trade_history=history,
            loan_history=loan_history,
            asset_movements=asset_movements,
            eth_transactions=eth_transactions,
        )
        return result, error_or_empty

    def query_fiat_balances(self) -> Dict[Asset, Dict[str, FVal]]:
        """Return each stored fiat balance with its amount and USD value."""
        result = {}
        balances = self.data.get_fiat_balances()
        for currency, str_amount in balances.items():
            amount = FVal(str_amount)
            usd_rate = Inquirer().query_fiat_pair(currency, A_USD)
            result[currency] = {
                'amount': amount,
                'usd_value': amount * usd_rate,
            }

        return result

    def query_balances(
            self,
            requested_save_data: bool = False,
            timestamp: Optional[Timestamp] = None,
    ) -> Dict[str, Any]:
        """Query all balances rotkehlchen can see.

        If requested_save_data is True then the data are saved in the DB.
        If timestamp is None then the current timestamp is used.
        If a timestamp is given then that is the time that the balances are going
        to be saved in the DB

        Data are only saved if every source (exchanges, blockchain) answered
        without error.

        Returns a dictionary with the queried balances.
        """
        log.info('query_balances called', requested_save_data=requested_save_data)

        balances = {}
        problem_free = True
        for exchange in self.exchange_manager.connected_exchanges.values():
            exchange_balances, _ = exchange.query_balances()
            # If we got an error, disregard that exchange but make sure we don't save data
            if not isinstance(exchange_balances, dict):
                problem_free = False
            else:
                balances[exchange.name] = exchange_balances

        result, error_or_empty = self.blockchain.query_balances()
        if error_or_empty == '':
            balances['blockchain'] = result['totals']
        else:
            problem_free = False

        fiat_balances = self.query_fiat_balances()
        if fiat_balances:
            balances['banks'] = fiat_balances

        combined = combine_stat_dicts(list(balances.values()))
        total_usd_per_location = [(k, dict_get_sumof(v, 'usd_value')) for k, v in balances.items()]

        # calculate net usd value
        net_usd = FVal(0)
        for v in combined.values():
            net_usd += FVal(v['usd_value'])

        stats: Dict[str, Any] = {
            'location': {},
            'net_usd': net_usd,
        }
        for name, total in total_usd_per_location:
            if net_usd != FVal(0):
                percentage = (total / net_usd).to_percentage()
            else:
                percentage = '0%'
            stats['location'][name] = {
                'usd_value': total,
                'percentage_of_net_value': percentage,
            }

        for k, v in combined.items():
            if net_usd != FVal(0):
                percentage = (v['usd_value'] / net_usd).to_percentage()
            else:
                percentage = '0%'
            combined[k]['percentage_of_net_value'] = percentage

        result_dict = merge_dicts(combined, stats)

        allowed_to_save = requested_save_data or self.data.should_save_balances()
        if problem_free and allowed_to_save:
            if not timestamp:
                timestamp = Timestamp(int(time.time()))
            self.data.save_balances_data(data=result_dict, timestamp=timestamp)
            log.debug('query_balances data saved')
        else:
            log.debug(
                'query_balances data not saved',
                allowed_to_save=allowed_to_save,
                problem_free=problem_free,
            )

        # After adding it to the saved file we can overlay additional data that
        # is not required to be saved in the history file
        try:
            details = self.accountant.events.details
        except AttributeError:
            # self.accountant only exists between unlock_user() and logout().
            details = {}

        for asset, (tax_free_amount, average_buy_value) in details.items():
            if asset not in result_dict:
                continue

            result_dict[asset]['tax_free_amount'] = tax_free_amount
            result_dict[asset]['average_buy_value'] = average_buy_value

            current_price = result_dict[asset]['usd_value'] / result_dict[asset]['amount']
            if average_buy_value != FVal(0):
                result_dict[asset]['percent_change'] = (
                    ((current_price - average_buy_value) / average_buy_value) * 100
                )
            else:
                result_dict[asset]['percent_change'] = 'INF'

        return result_dict

    def set_main_currency(self, currency_string: str) -> Tuple[bool, str]:
        """Takes a currency string from the API and sets it as the main currency for rotki

        Returns True and empty string for success and False and error string for error
        """
        try:
            currency = Asset(currency_string)
        except UnknownAsset:
            msg = f'An unknown asset {currency_string} was given for main currency'
            log.critical(msg)
            return False, msg

        if not currency.is_fiat():
            msg = f'A non-fiat asset {currency_string} was given for main currency'
            log.critical(msg)
            return False, msg

        fiat_currency = FiatAsset(currency.identifier)
        with self.lock:
            self.data.set_main_currency(fiat_currency, self.accountant)

        return True, ''

    def set_settings(self, settings: Dict[str, Any]) -> Tuple[bool, str]:
        """Validate and store new settings; returns (success, accumulated message).

        A rejected ``eth_rpc_endpoint`` is removed from ``settings`` (not saved)
        but does not fail the call; an invalid ``main_currency`` or a failed
        accountant customization does.
        """
        log.info('Add new settings')

        message = ''
        with self.lock:
            if 'eth_rpc_endpoint' in settings:
                result, msg = self.blockchain.set_eth_rpc_endpoint(settings['eth_rpc_endpoint'])
                if not result:
                    # Don't save it in the DB
                    del settings['eth_rpc_endpoint']
                    message += "\nEthereum RPC endpoint not set: " + msg

            if 'main_currency' in settings:
                given_symbol = settings['main_currency']
                try:
                    main_currency = Asset(given_symbol)
                except UnknownAsset:
                    return False, f'Unknown fiat currency {given_symbol} provided'
                except DeserializationError:
                    return False, 'Non string type given for fiat currency'

                if not main_currency.is_fiat():
                    msg = (
                        f'Provided symbol for main currency {given_symbol} is '
                        f'not a fiat currency'
                    )
                    return False, msg

            res, msg = self.accountant.customize(settings)
            if not res:
                message += '\n' + msg
                return False, message

            self.data.set_settings(settings, self.accountant)

            # Always return success here but with a message
            return True, message

    def setup_exchange(
            self,
            name: str,
            api_key: str,
            api_secret: str,
    ) -> Tuple[bool, str]:
        """
        Setup a new exchange with an api key and an api secret

        By default the api keys are always validated unless validate is False.
        """
        is_success, msg = self.exchange_manager.setup_exchange(
            name=name,
            api_key=api_key,
            api_secret=api_secret,
            database=self.data.db,
        )

        if is_success:
            # Success, save the result in the DB
            self.data.db.add_exchange(name, api_key, api_secret)
        return is_success, msg

    def remove_exchange(self, name: str) -> Tuple[bool, str]:
        """Remove a registered exchange from the manager and the DB."""
        if not self.exchange_manager.has_exchange(name):
            return False, 'Exchange {} is not registered'.format(name)

        self.exchange_manager.delete_exchange(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        return True, ''

    def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
        """Query for frequently changing data"""
        result: Dict[str, Union[bool, Timestamp]] = {}

        if self.user_is_logged_in:
            result['last_balance_save'] = self.data.db.get_last_balance_save_time()
            result['eth_node_connection'] = self.blockchain.ethchain.connected
            result['history_process_start_ts'] = self.accountant.started_processing_timestamp
            result['history_process_current_ts'] = self.accountant.currently_processing_timestamp
        return result

    def shutdown(self) -> None:
        """Log out the current user and signal ``main_loop`` to stop."""
        self.logout()
        self.shutdown_event.set()
def _init(self, maxsize):
    """Reset the pool's bookkeeping objects, then apply ``maxsize``.

    Fresh instances of the task queue, lock and semaphore (count 1) are
    created and the thread count is zeroed. All of these must exist before
    ``_set_maxsize`` runs, because it works on them.
    """
    self.task_queue = Queue()
    self._lock = Lock()
    self._semaphore = Semaphore(1)
    self._size = 0
    self._set_maxsize(maxsize)
def __init__(self):
    """Initialise ``_count`` to 0 and ``_current`` to None, and create ``_lock``."""
    self._count = 0
    self._current = None
    self._lock = Semaphore()
if ex.code != 400: raise log.debug('Username taken. Continuing') continue else: raise ValueError('Could not register or login!') name = encode_hex(signer.sign(client.user_id.encode())) user = client.get_user(client.user_id) user.set_display_name(name) return user @cached(cache=LRUCache(128), key=attrgetter('user_id', 'displayname'), lock=Semaphore()) def validate_userid_signature(user: User) -> Optional[Address]: """ Validate a userId format and signature on displayName, and return its address""" # display_name should be an address in the USERID_RE format match = USERID_RE.match(user.user_id) if not match: return None encoded_address = match.group(1) address: Address = to_canonical_address(encoded_address) try: displayname = user.get_display_name() recovered = recover( data=user.user_id.encode(), signature=decode_hex(displayname),
def __init__(self, web3: Web3, filter_params: dict):
    """Initialise the parent with ``filter_id=None`` and keep ``filter_params`` locally.

    ``_last_block`` starts at -1 and ``_lock`` is a fresh Semaphore.
    """
    super().__init__(web3, filter_id=None)
    self._lock = Semaphore()
    self._last_block: int = -1
    self.filter_params = filter_params
class RaidenService(Runnable):
    """ A Raiden node. """

    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: typing.BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            private_key_bin,
            transport,
            raiden_event_handler,
            config,
            discovery=None,
    ):
        """Build a node in the stopped state; ``start()`` brings it up.

        Raises ``ValueError('invalid private_key')`` unless ``private_key_bin``
        is a ``bytes`` object of exactly 32 bytes. When
        ``config['database_path']`` is not ``':memory:'`` the parent directory
        is created and a ``filelock.FileLock`` on ``<dir>/.lock`` is prepared
        (it is acquired in ``start()``).
        """
        super().__init__()
        if not isinstance(private_key_bin, bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        self.tokennetworkids_to_connectionmanagers = dict()
        # Pending payment results, keyed by payment identifier.
        self.identifier_to_results: typing.Dict[
            typing.PaymentID,
            AsyncResult,
        ] = dict()

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)
        self.discovery = discovery

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Prevent concurrent access to the same db
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        # Serialises blockchain event polling with filter installation.
        self.event_poll_lock = gevent.lock.Semaphore()

    def start(self):
        """ Start the node synchronously. Raises directly if anything went wrong on startup

        Raises ``RuntimeError`` if the node is already started, or if the
        restored state knows payment networks but not the configured
        default registry.
        """
        if not self.stop_event.ready():
            raise RuntimeError(f'{self!r} already started')
        self.stop_event.clear()

        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        storage = sqlite.SQLiteStorage(self.database_path, serialize.JSONSerializer())
        self.wal = wal.restore_to_state_change(
            transition_function=node.state_transition,
            storage=storage,
            state_change_identifier='latest',
        )

        if self.wal.state_manager.current_state is None:
            log.debug(
                'No recoverable state available, created inital state',
                node=pex(self.address),
            )
            block_number = self.chain.block_number()

            state_change = ActionInitChain(
                random.Random(),
                block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.wal.log_and_dispatch(state_change)
            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states as it's the node's startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.EMPTY_HASH,
                payment_network,
            )
            self.handle_state_change(state_change)

            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = 0
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(self.wal.state_manager.current_state)
            log.debug(
                'Restored state from WAL',
                last_restored_block=last_log_block_number,
                node=pex(self.address),
            )

            known_networks = views.get_payment_network_identifiers(views.state_from_raiden(self))
            if known_networks and self.default_registry.address not in known_networks:
                configured_registry = pex(self.default_registry.address)
                known_registries = lpex(known_networks)
                raise RuntimeError(
                    f'Token network address mismatch.\n'
                    f'Raiden is configured to use the smart contract '
                    f'{configured_registry}, which conflicts with the current known '
                    f'smart contracts {known_registries}',
                )

        # Clear ref cache & disable caching
        serialize.RaidenJSONDecoder.ref_cache.clear()
        serialize.RaidenJSONDecoder.cache_object_references = False

        # Restore the current snapshot group
        state_change_qty = self.wal.storage.count_state_changes()
        self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        # otherwise the state changes won't have effect.
        # - The alarm must complete its first run before the transport is started,
        #  to avoid rejecting messages for unknown channels.
        # NOTE(review): the code below starts the transport *before*
        # alarm.first_run(); the comment after this explains why. Confirm
        # which of the two notes is current.
        self.alarm.register_callback(self._callback_new_block)

        # alarm.first_run may process some new channel, which would start_health_check_for
        # a partner, that's why transport needs to be already started at this point
        self.transport.start(self)

        self.alarm.first_run()

        chain_state = views.state_from_raiden(self)
        # Dispatch pending transactions
        pending_transactions = views.get_pending_transactions(
            chain_state,
        )
        log.debug(
            'Processing pending transactions',
            num_pending_transactions=len(pending_transactions),
            node=pex(self.address),
        )
        # While this lock is held, handle_state_change() returns early without
        # dispatching events.
        with self.dispatch_events_lock:
            for transaction in pending_transactions:
                try:
                    self.raiden_event_handler.on_raiden_event(self, transaction)
                except RaidenRecoverableError as e:
                    log.error(str(e))
                except InvalidDBData:
                    raise
                except RaidenUnrecoverableError as e:
                    # Only logged (not raised) on the main network.
                    if self.config['network_type'] == NetworkType.MAIN:
                        log.error(str(e))
                    else:
                        raise

        self.alarm.start()

        # after transport and alarm is started, send queued messages
        events_queues = views.get_all_messagequeues(chain_state)

        for queue_identifier, event_queue in events_queues.items():
            self.start_health_check_for(queue_identifier.recipient)

            # repopulate identifier_to_results for pending transfers
            for event in event_queue:
                if type(event) == SendDirectTransfer:
                    self.identifier_to_results[event.payment_identifier] = AsyncResult()

                message = message_from_sendevent(event, self.address)
                self.sign(message)
                self.transport.send_async(queue_identifier, message)

        # exceptions on these subtasks should crash the app and bubble up
        self.alarm.link_exception(self.on_error)
        self.transport.link_exception(self.on_error)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck()

        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet.get()  # re-raise if exception occurred

        super().start()

    def _run(self):
        """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """
        try:
            self.stop_event.wait()
        except gevent.GreenletExit:  # killed without exception
            self.stop_event.set()
            gevent.killall([self.alarm, self.transport])  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()
            raise

    def stop(self):
        """ Stop the node gracefully. Raise if any stop-time error occurred on any subtask """
        if self.stop_event.ready():  # not started
            return

        # Needs to come before any greenlets joining
        self.stop_event.set()

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # NOTE(review): an earlier note said a timeout is needed here to avoid
        # looping on a disconnected client, but the joins below pass none —
        # confirm whether one should be added.
        self.transport.stop()
        self.alarm.stop()

        self.transport.join()
        self.alarm.join()

        self.blockchain_events.uninstall_all_event_listeners()

        if self.db_lock is not None:
            self.db_lock.release()

    def add_pending_greenlet(self, greenlet: gevent.Greenlet):
        """Route exceptions raised by *greenlet* to ``self.on_error``."""
        greenlet.link_exception(self.on_error)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def start_neighbours_healthcheck(self):
        """Start a health check for every neighbour except the bootstrap address."""
        for neighbour in views.all_neighbour_nodes(self.wal.state_manager.current_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def get_block_number(self):
        """Return the block number recorded in the node's current WAL state."""
        return views.block_number(self.wal.state_manager.current_state)

    def handle_state_change(self, state_change):
        """Log *state_change* to the WAL and dispatch the resulting events.

        Returns the list of events produced by the WAL, or ``[]`` without
        dispatching anything while ``dispatch_events_lock`` is held (the
        state change is still logged in that case). After dispatching, a WAL
        snapshot is stored whenever the count of state changes crosses
        another multiple of ``SNAPSHOT_STATE_CHANGES_COUNT``.
        """
        log.debug('STATE CHANGE', node=pex(self.address), state_change=state_change)

        event_list = self.wal.log_and_dispatch(state_change)

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug('RAIDEN EVENT', node=pex(self.address), raiden_event=event)

            try:
                self.raiden_event_handler.on_raiden_event(
                    raiden=self,
                    event=event,
                )
            except RaidenRecoverableError as e:
                log.error(str(e))
            except InvalidDBData:
                raise
            except RaidenUnrecoverableError as e:
                # Only logged (not raised) on the main network.
                if self.config['network_type'] == NetworkType.MAIN:
                    log.error(str(e))
                else:
                    raise

        # Take a snapshot every SNAPSHOT_STATE_CHANGES_COUNT
        # TODO: Gather more data about storage requirements
        # and update the value to specify how often we need
        # capturing a snapshot should take place
        new_snapshot_group = self.wal.storage.count_state_changes() // SNAPSHOT_STATE_CHANGES_COUNT
        if new_snapshot_group > self.snapshot_group:
            log.debug(f'Storing snapshot: {new_snapshot_group}')
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        return event_list

    def set_node_network_state(self, node_address, network_state):
        """Log a network-state change for *node_address* to the WAL.

        Uses ``wal.log_and_dispatch`` directly, so resulting events are not
        passed to ``raiden_event_handler``.
        """
        state_change = ActionChangeNodeNetworkState(node_address, network_state)
        self.wal.log_and_dispatch(state_change)

    def start_health_check_for(self, node_address):
        """Ask the transport to start health-checking *node_address*."""
        self.transport.start_health_check(node_address)

    def _callback_new_block(self, latest_block):
        """Called once a new block is detected by the alarm task.

        Note: This should be called only once per block, otherwise there will
        be duplicated `Block` state changes in the log.

        Therefore this method should be called only once a new block is mined
        with the corresponding block data from the AlarmTask.
        """
        # User facing APIs, which have on-chain side-effects, force polled the
        # blockchain to update the node's state. This force poll is used to
        # provide a consistent view to the user, e.g. a channel open call waits
        # for the transaction to be mined and force polled the event to update
        # the node's state. This pattern introduced a race with the alarm task
        # and the task which served the user request, because the events are
        # returned only once per filter. The lock below is to protect against
        # these races (introduced by the commit
        # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d)
        with self.event_poll_lock:
            latest_block_number = latest_block['number']
            for event in self.blockchain_events.poll_blockchain_events(latest_block_number):
                # These state changes will be procesed with a block_number
                # which is /larger/ than the ChainState's block_number.
                on_blockchain_event(self, event)

            # On restart the Raiden node will re-create the filters with the
            # ethereum node. These filters will have the from_block set to the
            # value of the latest Block state change. To avoid missing events
            # the Block state change is dispatched only after all of the events
            # have been processed.
            #
            # This means on some corner cases a few events may be applied
            # twice, this will happen if the node crashed and some events have
            # been processed but the Block state change has not been
            # dispatched.
            state_change = Block(
                block_number=latest_block_number,
                gas_limit=latest_block['gasLimit'],
                block_hash=bytes(latest_block['hash']),
            )
            self.handle_state_change(state_change)

    def sign(self, message):
        """ Sign message inplace.

        Raises ``ValueError`` if *message* is not a ``SignedMessage``.
        """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.private_key)

    def install_all_blockchain_filters(
            self,
            token_network_registry_proxy: TokenNetworkRegistry,
            secret_registry_proxy: SecretRegistry,
            from_block: typing.BlockNumber,
    ):
        """Install registry, secret-registry and per-token-network listeners.

        All listeners start at *from_block*. This holds ``event_poll_lock`` so
        it cannot interleave with ``_callback_new_block``.
        """
        with self.event_poll_lock:
            node_state = views.state_from_raiden(self)
            token_networks = views.get_token_network_identifiers(
                node_state,
                token_network_registry_proxy.address,
            )

            self.blockchain_events.add_token_network_registry_listener(
                token_network_registry_proxy,
                from_block,
            )
            self.blockchain_events.add_secret_registry_listener(
                secret_registry_proxy,
                from_block,
            )

            for token_network in token_networks:
                token_network_proxy = self.chain.token_network(token_network)
                self.blockchain_events.add_token_network_listener(
                    token_network_proxy,
                    from_block,
                )

    def connection_manager_for_token_network(self, token_network_identifier):
        """Return the ConnectionManager for a known token network, creating it on first use.

        Raises ``InvalidAddress`` if the identifier is not a binary address or
        is not registered under the default registry.
        """
        if not is_binary_address(token_network_identifier):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_identifiers(
            views.state_from_raiden(self),
            self.default_registry.address,
        )

        if token_network_identifier not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokennetworkids_to_connectionmanagers.get(token_network_identifier)

        if manager is None:
            manager = ConnectionManager(self, token_network_identifier)
            self.tokennetworkids_to_connectionmanagers[token_network_identifier] = manager

        return manager

    def leave_all_token_networks(self):
        """Log an ``ActionLeaveAllNetworks`` state change to the WAL.

        Uses ``wal.log_and_dispatch`` directly, so resulting events are not
        passed to ``raiden_event_handler``.
        """
        state_change = ActionLeaveAllNetworks()
        self.wal.log_and_dispatch(state_change)

    def close_and_settle(self):
        """Leave all token networks, then wait for all channels to settle.

        Waiting only happens when at least one connection manager exists.
        """
        log.info('raiden will close and settle all channels now')

        self.leave_all_token_networks()

        if self.tokennetworkids_to_connectionmanagers:
            waiting.wait_for_settle_all_channels(
                self,
                self.alarm.sleep_time,
            )

    def mediated_transfer_async(
            self,
            token_network_identifier: typing.TokenNetworkID,
            amount: typing.TokenAmount,
            target: typing.Address,
            identifier: typing.PaymentID,
    ):
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer, the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.

        Returns the AsyncResult tracking the payment (see start_mediated_transfer).
        """

        async_result = self.start_mediated_transfer(
            token_network_identifier,
            amount,
            target,
            identifier,
        )

        return async_result

    def direct_transfer_async(self, token_network_identifier, amount, target, identifier):
        """ Do a direct transfer with target.

        Direct transfers are non cancellable and non expirable, since these
        transfers are a signed balance proof with the transferred amount
        incremented.

        Because the transfer is non cancellable, there is a level of trust with
        the target. After the message is sent the target is effectively paid
        and then it is not possible to revert.

        The async result will be set to False iff there is no direct channel
        with the target or the payer does not have balance to complete the
        transfer, otherwise because the transfer is non expirable the async
        result *will never be set to False* and if the message is sent it will
        hang until the target node acknowledge the message.

        This transfer should be used as an optimization, since only two packets
        are required to complete the transfer (from the payers perspective),
        whereas the mediated transfer requires 6 messages.

        Returns the AsyncResult registered in ``identifier_to_results`` for
        this payment. When *identifier* is None one is generated here, so the
        return value is the caller's only handle on the result.
        """

        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        direct_transfer = ActionTransferDirect(
            token_network_identifier,
            target,
            identifier,
            amount,
        )

        async_result = AsyncResult()
        self.identifier_to_results[identifier] = async_result

        self.handle_state_change(direct_transfer)

        return async_result

    def start_mediated_transfer(
            self,
            token_network_identifier: typing.TokenNetworkID,
            amount: typing.TokenAmount,
            target: typing.Address,
            identifier: typing.PaymentID,
    ):
        """Dispatch an initiator state change for a mediated transfer.

        If *identifier* already has a pending AsyncResult, that result is
        returned and nothing new is dispatched. Otherwise a new AsyncResult is
        registered and returned.
        """

        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        if identifier in self.identifier_to_results:
            return self.identifier_to_results[identifier]

        async_result = AsyncResult()
        self.identifier_to_results[identifier] = async_result

        secret = random_secret()
        init_initiator_statechange = initiator_init(
            self,
            identifier,
            amount,
            secret,
            token_network_identifier,
            target,
        )

        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return async_result

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        """Dispatch the mediator-initialisation state change for *transfer*."""
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        """Health-check the initiator and dispatch the target-initialisation state change."""
        self.start_health_check_for(transfer.initiator)
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)
def __init__(
        self,
        chain: BlockChainService,
        query_start_block: BlockNumber,
        default_registry: TokenNetworkRegistry,
        default_secret_registry: SecretRegistry,
        private_key_bin,
        transport,
        raiden_event_handler,
        message_handler,
        config,
        discovery=None,
):
    """Build a Raiden node in the stopped state.

    Raises ``ValueError('invalid private_key')`` unless ``private_key_bin`` is
    a ``bytes`` object of exactly 32 bytes. For any ``config['database_path']``
    other than ``':memory:'`` the parent directory is created and a
    ``filelock.FileLock`` on ``<dir>/.lock`` is prepared; ``':memory:'``
    skips all filesystem setup.
    """
    super().__init__()
    key_ok = isinstance(private_key_bin, bytes) and len(private_key_bin) == 32
    if not key_ok:
        raise ValueError('invalid private_key')

    self.tokennetworkids_to_connectionmanagers = {}
    self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

    # Chain access, registries and configuration.
    self.chain: BlockChainService = chain
    self.default_registry = default_registry
    self.query_start_block = query_start_block
    self.default_secret_registry = default_secret_registry
    self.config = config

    # Identity derived from the private key.
    self.privkey = private_key_bin
    self.address = privatekey_to_address(private_key_bin)
    self.discovery = discovery
    self.private_key = PrivateKey(private_key_bin)
    self.pubkey = self.private_key.public_key.format(compressed=False)

    # Collaborators.
    self.transport = transport
    self.blockchain_events = BlockchainEvents()
    self.alarm = AlarmTask(chain)
    self.raiden_event_handler = raiden_event_handler
    self.message_handler = message_handler

    # A set stop_event means "not running".
    stopped = Event()
    self.stop_event = stopped
    stopped.set()

    self.wal = None
    self.snapshot_group = 0

    # This flag will be used to prevent the service from processing
    # state changes events until we know that pending transactions
    # have been dispatched.
    self.dispatch_events_lock = Semaphore(1)

    self.contract_manager = ContractManager(config['contracts_path'])
    self.database_path = config['database_path']
    if self.database_path == ':memory:':
        self.database_path = ':memory:'
        self.database_dir = None
        self.lock_file = None
        self.serialization_file = None
        self.db_lock = None
    else:
        db_dir = os.path.dirname(config['database_path'])
        os.makedirs(db_dir, exist_ok=True)

        self.database_dir = db_dir
        # Prevent concurrent access to the same db
        self.lock_file = os.path.join(self.database_dir, '.lock')
        self.db_lock = filelock.FileLock(self.lock_file)

    self.event_poll_lock = gevent.lock.Semaphore()
    self.gas_reserve_lock = gevent.lock.Semaphore()
    self.payment_identifier_lock = gevent.lock.Semaphore()
def __init__(self, *args, **kwargs):
    """Initialise the server and create the semaphore stored as ``_writelock``.

    All arguments are forwarded unchanged to ``BaseServer.__init__``.
    """
    # The raw (non-gevent) socket, if possible
    self._socket = None
    BaseServer.__init__(self, *args, **kwargs)
    from gevent.lock import Semaphore as _WriteSemaphore
    self._writelock = _WriteSemaphore()
def __init__(self) -> None:
    """Set up a per-key semaphore map plus a semaphore guarding the map itself."""
    super().__init__()
    # defaultdict creates a fresh Semaphore the first time a key is looked up.
    per_query_locks: Dict[int, Semaphore] = defaultdict(Semaphore)
    # Accessing and writing to the query_locks map also needs to be protected
    map_guard = Semaphore()
    self.query_locks_map = per_query_locks
    self.query_locks_map_lock = map_guard
# Commented-out usage example (not executed):
#
# for i in range(5):
#     print("I am fun 2 this is %s"%i)
#     gevent.sleep(0)
#
# fun1()
# fun2()
#
# t1 = gevent.spawn(fun1)
# t2 = gevent.spawn(fun2)
#
# gevent.joinall([t1,t2])
import gevent
from gevent.lock import Semaphore

# A single-permit semaphore shared by fun1 and fun2: at most one of them
# holds it (i.e. is printing) at any moment.
sem = Semaphore(1)


def _print_numbered(label):
    """Print ``I am fun <label> this is <i>`` for i in 0..4, holding ``sem`` per line.

    ``with sem:`` guarantees the semaphore is released even if ``print``
    raises; the previous bare acquire()/release() pair would leave it held
    forever, blocking every other user of ``sem``.
    """
    for i in range(5):
        with sem:
            print("I am fun %s this is %s" % (label, i))


def fun1():
    """Print five numbered "fun 1" lines, one semaphore hold per line."""
    _print_numbered(1)


def fun2():
    """Print five numbered "fun 2" lines, one semaphore hold per line."""
    _print_numbered(2)

# fun1()
# fun2()
def __init__(self, name, task):
    """Store *name* and *task*; ``self.lock`` admits ``task.max_concurrent`` holders."""
    limit = task.max_concurrent
    self.name, self.task = name, task
    self.lock = Semaphore(limit)
class Queue(Greenlet):
    """Manages the queue of |Envelope| objects waiting for delivery. This is
    not a standard FIFO queue, a message's place in the queue depends entirely
    on the timestamp of its next delivery attempt.

    :param store: Object implementing :class:`QueueStorage`.
    :param relay: |Relay| object used to attempt message deliveries. If this
                  is not given, no deliveries will be attempted on received
                  messages.
    :param backoff: Function that, given an |Envelope| and number of delivery
                    attempts, will return the number of seconds before the next
                    attempt. If it returns ``None``, the message will be
                    permanently failed. The default backoff function simply
                    returns ``None`` and messages are never retried.
    :param bounce_factory: Function that produces a |Bounce| object given the
                           same parameters as the |Bounce| constructor. If the
                           function returns ``None``, no bounce is delivered.
                           By default, a new |Bounce| is created in every case.
    :param bounce_queue: |Queue| object that will be used for delivering bounce
                         messages. The default is ``self``.
    :param store_pool: Number of simultaneous operations performable against
                       the ``store`` object. Default is unlimited.
    :param relay_pool: Number of simultaneous operations performable against
                       the ``relay`` object. Default is unlimited.

    """

    def __init__(self, store, relay=None, backoff=None, bounce_factory=None,
                 bounce_queue=None, store_pool=None, relay_pool=None):
        super(Queue, self).__init__()
        self.store = store
        self.relay = relay
        self.backoff = backoff or self._default_backoff
        self.bounce_factory = bounce_factory or Bounce
        # Explicit None check: a |Queue| is a gevent Greenlet, which is falsy
        # until started (and after it dies), so ``bounce_queue or self`` would
        # silently ignore a bounce queue that has not been started yet.
        self.bounce_queue = bounce_queue if bounce_queue is not None else self
        self.wake = Event()
        # Sorted (timestamp, id) entries waiting for their next attempt.
        self.queued = []
        # Ids whose delivery is currently being attempted.
        self.active_ids = set()
        # Ids present in self.queued.
        self.queued_ids = set()
        self.queued_lock = Semaphore(1)
        self.queue_policies = []
        self._use_pool('store_pool', store_pool)
        self._use_pool('relay_pool', relay_pool)

    def add_policy(self, policy):
        """Adds a |QueuePolicy| to be executed before messages are persisted
        to storage.

        :param policy: |QueuePolicy| object to execute.

        """
        if isinstance(policy, QueuePolicy):
            self.queue_policies.append(policy)
        else:
            raise TypeError('Argument not a QueuePolicy.')

    @staticmethod
    def _default_backoff(envelope, attempts):
        # Returning None means "never retry".
        pass

    def _run_policies(self, envelope):
        # Apply each policy in order. A policy may replace an envelope with
        # several; every replacement continues through the remaining policies.
        results = [envelope]

        def recurse(current, i):
            try:
                policy = self.queue_policies[i]
            except IndexError:
                return
            ret = policy.apply(current)
            if ret:
                results.remove(current)
                results.extend(ret)
                for env in ret:
                    recurse(env, i + 1)
            else:
                recurse(current, i + 1)
        recurse(envelope, 0)
        return results

    def _use_pool(self, attr, pool):
        # None: leave the attribute unset (unlimited); an int becomes a Pool.
        if pool is None:
            return
        if isinstance(pool, Pool):
            setattr(self, attr, pool)
        else:
            setattr(self, attr, Pool(pool))

    def _pool_run(self, which, func, *args, **kwargs):
        pool = getattr(self, which + '_pool', None)
        # ``is not None`` rather than truthiness: a gevent Pool's len() is the
        # number of greenlets it is running, so an idle pool is falsy and the
        # call would bypass the pool's concurrency limit.
        if pool is not None:
            return pool.spawn(func, *args, **kwargs).get()
        return func(*args, **kwargs)

    def _pool_imap(self, which, func, *iterables):
        pool = getattr(self, which + '_pool', gevent)
        # Spawn every call before joining any: joining inside the lazy imap()
        # loop would spawn-then-join one call at a time, running them serially.
        threads = list(imap(pool.spawn, repeat(func), *iterables))
        ret = []
        for thread in threads:
            thread.join()
            ret.append(thread.exception or thread.value)
        return ret

    def _pool_spawn(self, which, func, *args, **kwargs):
        pool = getattr(self, which + '_pool', gevent)
        return pool.spawn(func, *args, **kwargs)

    def _add_queued(self, entry):
        timestamp, id = entry
        # Two membership tests instead of building a union set on every call.
        if id not in self.queued_ids and id not in self.active_ids:
            bisect.insort(self.queued, entry)
            self.queued_ids.add(id)
            self.wake.set()

    def enqueue(self, envelope):
        """Drops a new message in the queue for delivery. The first delivery
        attempt is made immediately (depending on relay pool availability).
        This method is not typically called directly, |Edge| objects use it
        when they receive new messages.

        :param envelope: |Envelope| object to enqueue.
        :returns: Zipped list of envelopes and their respective queue IDs (or
                  thrown :exc:`QueueError` objects).

        """
        now = time.time()
        envelopes = self._run_policies(envelope)
        ids = self._pool_imap('store', self.store.write, envelopes,
                              repeat(now))
        results = list(zip(envelopes, ids))
        for env, id in results:
            if not isinstance(id, BaseException):
                if self.relay:
                    self.active_ids.add(id)
                    self._pool_spawn('relay', self._attempt, id, env, 0)
            elif not isinstance(id, QueueError):
                raise id  # Re-raise exceptions that are not QueueError.
        return results

    def _load_all(self):
        for entry in self.store.load():
            self._add_queued(entry)

    def _remove(self, id):
        self._pool_spawn('store', self.store.remove, id)
        self.queued_ids.discard(id)
        self.active_ids.discard(id)

    def _bounce(self, envelope, reply):
        bounce = self.bounce_factory(envelope, reply)
        if bounce:
            return self.bounce_queue.enqueue(bounce)

    def _perm_fail(self, id, envelope, reply):
        if id is not None:
            self._remove(id)
        if envelope.sender:  # Can't bounce to null-sender.
            self._pool_spawn('bounce', self._bounce, envelope, reply)

    def _split_by_reply(self, envelope, replies):
        # Group recipients that received an equal reply into one envelope copy.
        if isinstance(replies, Reply):
            return [(replies, envelope)]
        groups = []
        for i, rcpt in enumerate(envelope.recipients):
            for reply, group_env in groups:
                if replies[i] == reply:
                    group_env.recipients.append(rcpt)
                    break
            else:
                group_env = envelope.copy([rcpt])
                groups.append((replies[i], group_env))
        return groups

    def _retry_later(self, id, envelope, replies):
        # Returns False (after permanently failing and removing the message)
        # when backoff gives up, True when the message was re-scheduled.
        attempts = self.store.increment_attempts(id)
        wait = self.backoff(envelope, attempts)
        if wait is None:
            for reply, group_env in self._split_by_reply(envelope, replies):
                reply.message += ' (Too many retries)'
                self._perm_fail(None, group_env, reply)
            self._remove(id)
            return False
        else:
            when = time.time() + wait
            self.store.set_timestamp(id, when)
            self.active_ids.discard(id)
            self._add_queued((when, id))
            return True

    def _attempt(self, id, envelope, attempts):
        try:
            results = self.relay._attempt(envelope, attempts)
        except TransientRelayError as e:
            self._pool_spawn('store', self._retry_later, id, envelope, e.reply)
        except PermanentRelayError as e:
            self._perm_fail(id, envelope, e.reply)
        except Exception as e:
            log_exception(__name__)
            reply = Reply('450', '4.0.0 Unhandled delivery error: ' + str(e))
            self._pool_spawn('store', self._retry_later, id, envelope, reply)
            raise
        else:
            # The collections.Mapping/Sequence aliases were removed in
            # Python 3.10; use them only where collections.abc is missing.
            try:
                from collections.abc import Mapping, Sequence
            except ImportError:  # Python 2
                from collections import Mapping, Sequence
            if isinstance(results, Mapping):
                self._handle_partial_relay(id, envelope, attempts, results)
            elif isinstance(results, Sequence):
                results = dict(zip(envelope.recipients, results))
                self._handle_partial_relay(id, envelope, attempts, results)
            else:
                self._remove(id)

    def _handle_partial_relay(self, id, envelope, attempts, results):
        delivered = set()
        tempfails = []
        permfails = []
        for rcpt, rcpt_res in results.items():
            if rcpt_res is None or isinstance(rcpt_res, Reply):
                delivered.add(envelope.recipients.index(rcpt))
            elif isinstance(rcpt_res, PermanentRelayError):
                delivered.add(envelope.recipients.index(rcpt))
                permfails.append((rcpt, rcpt_res.reply))
            elif isinstance(rcpt_res, TransientRelayError):
                tempfails.append((rcpt, rcpt_res.reply))
        if permfails:
            rcpts, replies = zip(*permfails)
            fail_env = envelope.copy(rcpts)
            for reply, group_env in self._split_by_reply(fail_env, replies):
                self._perm_fail(None, group_env, reply)
        if tempfails:
            rcpts, replies = zip(*tempfails)
            fail_env = envelope.copy(rcpts)
            if not self._retry_later(id, fail_env, replies):
                return
        else:
            # No recipient needs a retry: drop the message and forget its id,
            # otherwise it would stay in active_ids forever.
            self.store.remove(id)
            self.active_ids.discard(id)
            self.queued_ids.discard(id)
            return
        self.store.set_recipients_delivered(id, delivered)

    def _dequeue(self, id):
        try:
            envelope, attempts = self.store.get(id)
        except KeyError:
            return
        self.active_ids.add(id)
        self._pool_spawn('relay', self._attempt, id, envelope, attempts)

    def _check_ready(self, now):
        # Dequeue every entry whose timestamp has arrived; self.queued is
        # sorted, so stop at the first future entry.
        last_i = 0
        for i, entry in enumerate(self.queued):
            timestamp, entry_id = entry
            if now >= timestamp:
                self._pool_spawn('store', self._dequeue, entry_id)
                last_i = i + 1
            else:
                break
        if last_i > 0:
            self.queued = self.queued[last_i:]
            self.queued_ids = set([id for _, id in self.queued])

    def _wait_store(self):
        while True:
            try:
                for entry in self.store.wait():
                    self._add_queued(entry)
            except NotImplementedError:
                return

    def _wait_ready(self, now):
        # Sleep until woken, or until the earliest queued entry is due.
        try:
            first = self.queued[0]
        except IndexError:
            self.wake.wait()
            self.wake.clear()
            return
        first_timestamp = first[0]
        if first_timestamp > now:
            self.wake.wait(first_timestamp - now)
            self.wake.clear()

    def flush(self):
        """Attempts to immediately flush all messages waiting in the queue,
        regardless of their retry timers.

        .. warning::

           This can be a very expensive operation, use with care.

        """
        self.wake.set()
        self.wake.clear()
        with self.queued_lock:
            for entry in self.queued:
                self._pool_spawn('store', self._dequeue, entry[1])
            self.queued = []
            # Keep queued_ids in sync with self.queued (as _check_ready does).
            # Stale ids would make _add_queued() refuse to re-queue a flushed
            # message after a transient failure.
            self.queued_ids = set()

    def kill(self):
        """This method is used by |Queue| and |Queue|-like objects to properly
        end any associated services (such as running :class:`~gevent.Greenlet`
        threads) and close resources.

        """
        super(Queue, self).kill()

    def _run(self):
        if not self.relay:
            return
        self._pool_spawn('store', self._load_all)
        self._pool_spawn('store', self._wait_store)

        while True:
            with self.queued_lock:
                now = time.time()
                self._check_ready(now)
                self._wait_ready(now)
class RaidenService(Runnable):
    """ A Raiden node.

    Owns the write-ahead log (WAL) of state changes, the blockchain event
    filters, the alarm task that reacts to new blocks and the transport used
    to exchange messages with other nodes. The service is created in the
    stopped state; call :meth:`start` to run it and :meth:`stop` to shut it
    down.
    """

    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            transport,
            raiden_event_handler,
            message_handler,
            config,
            discovery=None,
    ):
        """ Wire up the node's collaborators without starting anything.

        Args:
            chain: Blockchain proxy; the private key of its client becomes the
                node's signing key.
            query_start_block: Block from which events are fetched on the very
                first run (when no state can be restored from the WAL).
            default_registry: Token network registry proxy used by the node.
            default_secret_registry: Secret registry proxy used by the node.
            transport: Message transport, started by :meth:`start`.
            raiden_event_handler: Receives the events produced by state changes.
            message_handler: Receives the messages delivered by the transport.
            config: Node configuration mapping (``database_path``,
                ``contracts_path``, ``transport_type``, ``blockchain``, ...).
            discovery: Endpoint discovery, only used with the UDP transport.
        """
        super().__init__()
        self.tokennetworkids_to_connectionmanagers = dict()
        self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config

        self.signer: Signer = LocalSigner(self.chain.client.privkey)
        self.address = self.signer.address
        self.discovery = discovery
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler
        self.message_handler = message_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.contract_manager = ContractManager(config['contracts_path'])
        self.database_path = config['database_path']
        # Set on both branches so the attribute always exists.
        self.serialization_file = None
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            # A bare filename has an empty dirname (the current directory);
            # `os.makedirs('')` would raise, so only create real directories.
            if database_dir:
                os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Two raiden processes must not write to the same database, even
            # though the database itself may be consistent. If more than one
            # nodes writes state changes to the same WAL there are no
            # guarantees about recovery, this happens because during recovery
            # the WAL replay can not be deterministic.
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_dir = None
            self.lock_file = None
            self.db_lock = None

        self.event_poll_lock = gevent.lock.Semaphore()
        self.gas_reserve_lock = gevent.lock.Semaphore()
        self.payment_identifier_lock = gevent.lock.Semaphore()

    def start(self):
        """ Start the node synchronously. Raises directly if anything went wrong on startup

        Steps, in order:
        1. Take the database file lock (non-blocking; `filelock` presumably
           raises its `Timeout` if another process holds it).
        2. Restore state from the WAL, or initialize it on the first run.
        3. Install the blockchain filters and run the alarm task's first run.
        4. Dispatch pending transactions and re-queue pending messages.
        5. Start the transport, the alarm task and the health checks.

        Raises:
            RuntimeError: if the node is already started, or if the restored
                state knows payment networks but not the configured registry.
        """
        if not self.stop_event.ready():
            raise RuntimeError(f'{self!r} already started')
        self.stop_event.clear()

        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        self.maybe_upgrade_db()

        storage = sqlite.SerializedSQLiteStorage(
            database_path=self.database_path,
            serializer=serialize.JSONSerializer(),
        )
        storage.log_run()
        self.wal = wal.restore_to_state_change(
            transition_function=node.state_transition,
            storage=storage,
            state_change_identifier='latest',
        )

        if self.wal.state_manager.current_state is None:
            log.debug(
                'No recoverable state available, created inital state',
                node=pex(self.address),
            )
            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = self.query_start_block

            state_change = ActionInitChain(
                random.Random(),
                last_log_block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.handle_state_change(state_change)

            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states as it's the node's startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.EMPTY_HASH,
                payment_network,
                last_log_block_number,
            )
            self.handle_state_change(state_change)
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(self.wal.state_manager.current_state)
            log.debug(
                'Restored state from WAL',
                last_restored_block=last_log_block_number,
                node=pex(self.address),
            )

            known_networks = views.get_payment_network_identifiers(views.state_from_raiden(self))
            if known_networks and self.default_registry.address not in known_networks:
                configured_registry = pex(self.default_registry.address)
                known_registries = lpex(known_networks)
                raise RuntimeError(
                    f'Token network address mismatch.\n'
                    f'Raiden is configured to use the smart contract '
                    f'{configured_registry}, which conflicts with the current known '
                    f'smart contracts {known_registries}',
                )

        # Restore the current snapshot group
        state_change_qty = self.wal.storage.count_state_changes()
        self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        #   otherwise the state changes won't have effect.
        # - The alarm must complete its first run before the transport is started,
        #   to reject messages for closed/settled channels.
        self.alarm.register_callback(self._callback_new_block)
        # While this lock is held, `handle_state_change` logs state changes
        # but does not hand the resulting events to the event handler.
        with self.dispatch_events_lock:
            self.alarm.first_run(last_log_block_number)

        chain_state = views.state_from_raiden(self)
        self._initialize_transactions_queues(chain_state)
        self._initialize_whitelists(chain_state)

        # send messages in queue before starting transport,
        # this is necessary to avoid a race where, if the transport is started
        # before the messages are queued, actions triggered by it can cause new
        # messages to be enqueued before these older ones
        self._initialize_messages_queues(chain_state)

        # The transport must not ever be started before the alarm task's
        # `first_run()` has been, because it's this method which synchronizes the
        # node with the blockchain, including the channel's state (if the channel
        # is closed on-chain new messages must be rejected, which will not be the
        # case if the node is not synchronized)
        self.transport.start(
            raiden_service=self,
            message_handler=self.message_handler,
            prev_auth_data=chain_state.last_transport_authdata,
        )

        # First run has been called above!
        self.alarm.start()

        # exceptions on these subtasks should crash the app and bubble up
        self.alarm.link_exception(self.on_error)
        self.transport.link_exception(self.on_error)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck(chain_state)

        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet.get()  # re-raise if exception occurred

        log.debug('Raiden Service started', node=pex(self.address))
        super().start()

    def _run(self, *args, **kwargs):  # pylint: disable=method-hidden
        """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """
        try:
            self.stop_event.wait()
        except gevent.GreenletExit:  # killed without exception
            self.stop_event.set()
            gevent.killall([self.alarm, self.transport])  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()
            raise

    def stop(self):
        """ Stop the node gracefully.

        Raise if any stop-time error occurred on any subtask. Calling this on
        a node that is not running is a no-op. The database lock is released
        even when stopping a subtask raises.
        """
        if self.stop_event.ready():  # not started
            return

        # Needs to come before any greenlets joining
        self.stop_event.set()

        try:
            # Filters must be uninstalled after the alarm task has stopped. Since
            # the events are polled by an alarm task callback, if the filters are
            # uninstalled before the alarm task is fully stopped the callback
            # `poll_blockchain_events` will fail.
            #
            # NOTE(review): the joins below have no timeout, so a subtask that
            # keeps trying to contact a disconnected client can block here.
            self.transport.stop()
            self.alarm.stop()

            self.transport.join()
            self.alarm.join()

            self.blockchain_events.uninstall_all_event_listeners()
        finally:
            # `stop_event` is already set, so a later `stop()` would return
            # early; release the lock here or it would never be released.
            if self.db_lock is not None:
                self.db_lock.release()

        log.debug('Raiden Service stopped', node=pex(self.address))

    def add_pending_greenlet(self, greenlet: gevent.Greenlet):
        """ Make an exception in `greenlet` crash the service via `on_error`. """
        greenlet.link_exception(self.on_error)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def start_neighbours_healthcheck(self, chain_state: ChainState):
        """ Start a health check for every neighbour except the bootstrap address. """
        for neighbour in views.all_neighbour_nodes(chain_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def get_block_number(self) -> BlockNumber:
        """ Return the block number recorded in the current node state. """
        return views.block_number(self.wal.state_manager.current_state)

    def on_message(self, message: Message):
        """ Forward an incoming transport message to the message handler. """
        self.message_handler.on_message(self, message)

    def _should_log_unrecoverable_errors(self) -> bool:
        """ Return True when a `RaidenUnrecoverableError` should be logged
        instead of re-raised: only in the production environment, and only if
        the config does not ask unrecoverable errors to crash the node.
        """
        return (
            self.config['environment_type'] == Environment.PRODUCTION and
            not self.config['unrecoverable_error_should_crash']
        )

    def handle_state_change(self, state_change: StateChange):
        """ Log `state_change` to the WAL, dispatch it and handle the resulting events.

        While `dispatch_events_lock` is held the state change is still logged
        and applied, but its events are not handed to the event handler and
        an empty list is returned.

        Returns:
            The list of events produced by the state transition.

        Raises:
            InvalidDBData: always propagated from the event handler.
            RaidenUnrecoverableError: propagated unless
                `_should_log_unrecoverable_errors()` says to log it instead.
        """
        log.debug(
            'State change',
            node=pex(self.address),
            state_change=_redact_secret(serialize.JSONSerializer.serialize(state_change)),
        )

        event_list = self.wal.log_and_dispatch(state_change)

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug(
                'Raiden event',
                node=pex(self.address),
                raiden_event=_redact_secret(serialize.JSONSerializer.serialize(event)),
            )

            try:
                self.raiden_event_handler.on_raiden_event(
                    raiden=self,
                    event=event,
                )
            except RaidenRecoverableError as e:
                log.error(str(e))
            except InvalidDBData:
                raise
            except RaidenUnrecoverableError as e:
                if self._should_log_unrecoverable_errors():
                    log.error(str(e))
                else:
                    raise

        # Take a snapshot every SNAPSHOT_STATE_CHANGES_COUNT
        # TODO: Gather more data about storage requirements
        # and update the value to specify how often we need
        # capturing a snapshot should take place
        new_snapshot_group = self.wal.storage.count_state_changes() // SNAPSHOT_STATE_CHANGES_COUNT
        if new_snapshot_group > self.snapshot_group:
            log.debug('Storing snapshot', snapshot_id=new_snapshot_group)
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        return event_list

    def set_node_network_state(self, node_address: Address, network_state: str):
        """ Record a change of a peer's network reachability state. """
        state_change = ActionChangeNodeNetworkState(node_address, network_state)
        self.handle_state_change(state_change)

    def start_health_check_for(self, node_address: Address):
        """ Ask the transport to start a health check for `node_address`. """
        # This function is a noop during initialization. It can be called
        # through the alarm task while polling for new channel events. The
        # healthcheck will be started by self.start_neighbours_healthcheck()
        if self.transport:
            self.transport.start_health_check(node_address)

    def _callback_new_block(self, latest_block: Dict):
        """Called once a new block is detected by the alarm task.

        Note:
            This should be called only once per block, otherwise there will be
            duplicated `Block` state changes in the log.

            Therefore this method should be called only once a new block is
            mined with the corresponding block data from the AlarmTask.
        """
        # User facing APIs, which have on-chain side-effects, force polled the
        # blockchain to update the node's state. This force poll is used to
        # provide a consistent view to the user, e.g. a channel open call waits
        # for the transaction to be mined and force polled the event to update
        # the node's state. This pattern introduced a race with the alarm task
        # and the task which served the user request, because the events are
        # returned only once per filter. The lock below is to protect against
        # these races (introduced by the commit
        # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d)
        with self.event_poll_lock:
            latest_block_number = latest_block['number']
            confirmation_blocks = self.config['blockchain']['confirmation_blocks']
            # Handle testing private chains, which can be shorter than the
            # confirmation window. Clamp BEFORE fetching, otherwise the client
            # is asked for a negative block number.
            confirmed_block_number = max(
                GENESIS_BLOCK_NUMBER,
                latest_block_number - confirmation_blocks,
            )
            confirmed_block = self.chain.client.web3.eth.getBlock(confirmed_block_number)

            for event in self.blockchain_events.poll_blockchain_events(confirmed_block_number):
                # These state changes will be processed with a block_number
                # which is /larger/ than the ChainState's block_number.
                on_blockchain_event(self, event)

            # On restart the Raiden node will re-create the filters with the
            # ethereum node. These filters will have the from_block set to the
            # value of the latest Block state change. To avoid missing events
            # the Block state change is dispatched only after all of the events
            # have been processed.
            #
            # This means on some corner cases a few events may be applied
            # twice, this will happen if the node crashed and some events have
            # been processed but the Block state change has not been
            # dispatched.
            state_change = Block(
                block_number=confirmed_block_number,
                gas_limit=confirmed_block['gasLimit'],
                block_hash=bytes(confirmed_block['hash']),
            )
            self.handle_state_change(state_change)

    def _register_payment_status(
            self,
            target: TargetAddress,
            identifier: PaymentID,
            balance_proof: BalanceProofUnsignedState,
    ):
        """ Store a fresh PaymentStatus for (`target`, `identifier`), built from
        the balance proof's transferred amount and token network.
        """
        with self.payment_identifier_lock:
            self.targets_to_identifiers_to_statuses[target][identifier] = PaymentStatus(
                payment_identifier=identifier,
                amount=balance_proof.transferred_amount,
                token_network_identifier=balance_proof.token_network_identifier,
                payment_done=AsyncResult(),
            )

    def _initialize_transactions_queues(self, chain_state: ChainState):
        """ Hand every pending transaction of `chain_state` to the event handler.

        Runs with `dispatch_events_lock` held. Error handling matches
        `handle_state_change`.
        """
        pending_transactions = views.get_pending_transactions(chain_state)

        log.debug(
            'Processing pending transactions',
            num_pending_transactions=len(pending_transactions),
            node=pex(self.address),
        )

        with self.dispatch_events_lock:
            for transaction in pending_transactions:
                try:
                    self.raiden_event_handler.on_raiden_event(self, transaction)
                except RaidenRecoverableError as e:
                    log.error(str(e))
                except InvalidDBData:
                    raise
                except RaidenUnrecoverableError as e:
                    if self._should_log_unrecoverable_errors():
                        log.error(str(e))
                    else:
                        raise

    def _initialize_messages_queues(self, chain_state: ChainState):
        """ Push the queues to the transport and populate
        targets_to_identifiers_to_statuses.
        """
        events_queues = views.get_all_messagequeues(chain_state)

        for queue_identifier, event_queue in events_queues.items():
            self.start_health_check_for(queue_identifier.recipient)

            for event in event_queue:
                is_initiator = (
                    type(event) == SendLockedTransfer and
                    event.transfer.initiator == self.address
                )

                if is_initiator:
                    self._register_payment_status(
                        target=event.transfer.target,
                        identifier=event.transfer.payment_identifier,
                        balance_proof=event.transfer.balance_proof,
                    )

                message = message_from_sendevent(event, self.address)
                self.sign(message)
                self.transport.send_async(queue_identifier, message)

    def _initialize_whitelists(self, chain_state: ChainState):
        """ Whitelist neighbors and mediated transfer targets on transport """

        for neighbour in views.all_neighbour_nodes(chain_state):
            if neighbour == ConnectionManager.BOOTSTRAP_ADDR:
                continue
            self.transport.whitelist(neighbour)

        events_queues = views.get_all_messagequeues(chain_state)

        for event_queue in events_queues.values():
            for event in event_queue:
                is_initiator = (
                    type(event) == SendLockedTransfer and
                    event.transfer.initiator == self.address
                )
                if is_initiator:
                    self.transport.whitelist(address=event.transfer.target)

    def sign(self, message: Message):
        """ Sign message inplace.

        Raises:
            ValueError: if `message` is not a `SignedMessage`.
        """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.signer)

    def install_all_blockchain_filters(
            self,
            token_network_registry_proxy: TokenNetworkRegistry,
            secret_registry_proxy: SecretRegistry,
            from_block: BlockNumber,
    ):
        """ Install event listeners, starting at `from_block`, for the registry,
        the secret registry and every token network known to the node state.

        Holds `event_poll_lock` so the alarm task cannot poll concurrently.
        """
        with self.event_poll_lock:
            node_state = views.state_from_raiden(self)
            token_networks = views.get_token_network_identifiers(
                node_state,
                token_network_registry_proxy.address,
            )

            self.blockchain_events.add_token_network_registry_listener(
                token_network_registry_proxy=token_network_registry_proxy,
                contract_manager=self.contract_manager,
                from_block=from_block,
            )
            self.blockchain_events.add_secret_registry_listener(
                secret_registry_proxy=secret_registry_proxy,
                contract_manager=self.contract_manager,
                from_block=from_block,
            )

            for token_network in token_networks:
                token_network_proxy = self.chain.token_network(
                    TokenNetworkAddress(token_network),
                )
                self.blockchain_events.add_token_network_listener(
                    token_network_proxy=token_network_proxy,
                    contract_manager=self.contract_manager,
                    from_block=from_block,
                )

    def connection_manager_for_token_network(
            self,
            token_network_identifier: TokenNetworkID,
    ) -> ConnectionManager:
        """ Return the ConnectionManager for a token network, creating and
        caching it on first use.

        Raises:
            InvalidAddress: if the identifier is not a binary address or is not
                a token network of the default registry.
        """
        if not is_binary_address(token_network_identifier):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_identifiers(
            views.state_from_raiden(self),
            self.default_registry.address,
        )

        if token_network_identifier not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokennetworkids_to_connectionmanagers.get(token_network_identifier)

        if manager is None:
            manager = ConnectionManager(self, token_network_identifier)
            self.tokennetworkids_to_connectionmanagers[token_network_identifier] = manager

        return manager

    def mediated_transfer_async(
            self,
            token_network_identifier: TokenNetworkID,
            amount: TokenAmount,
            target: TargetAddress,
            identifier: PaymentID,
    ) -> PaymentStatus:
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer, the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.
        """
        secret = random_secret()
        payment_status = self.start_mediated_transfer_with_secret(
            token_network_identifier,
            amount,
            target,
            identifier,
            secret,
        )

        return payment_status

    def start_mediated_transfer_with_secret(
            self,
            token_network_identifier: TokenNetworkID,
            amount: TokenAmount,
            target: TargetAddress,
            identifier: PaymentID,
            secret: Secret,
    ) -> PaymentStatus:
        """ Start a mediated transfer as initiator, using `secret`.

        If `identifier` is None a default identifier is created. If a payment
        with the same identifier to `target` is already tracked, its status is
        returned when `payment_status.matches(...)` accepts the token network
        and amount.

        Raises:
            RaidenUnrecoverableError: if the secret's hash is already registered
                on-chain.
            PaymentConflict: if a tracked payment with the same identifier does
                not match the token network and amount.
        """
        secret_hash = sha3(secret)

        # LEFTODO: Supply a proper block id
        secret_registered = self.default_secret_registry.check_registered(
            secrethash=secret_hash,
            block_identifier='latest',
        )
        if secret_registered:
            raise RaidenUnrecoverableError(
                f'Attempted to initiate a locked transfer with secrethash {pex(secret_hash)}.'
                f' That secret is already registered onchain.',
            )

        self.start_health_check_for(Address(target))

        if identifier is None:
            identifier = create_default_identifier()

        with self.payment_identifier_lock:
            payment_status = self.targets_to_identifiers_to_statuses[target].get(identifier)
            if payment_status:
                payment_status_matches = payment_status.matches(
                    token_network_identifier,
                    amount,
                )
                if not payment_status_matches:
                    raise PaymentConflict(
                        'Another payment with the same id is in flight',
                    )

                return payment_status

            payment_status = PaymentStatus(
                payment_identifier=identifier,
                amount=amount,
                token_network_identifier=token_network_identifier,
                payment_done=AsyncResult(),
                secret=secret,
                secret_hash=secret_hash,
            )
            self.targets_to_identifiers_to_statuses[target][identifier] = payment_status

        init_initiator_statechange = initiator_init(
            raiden=self,
            transfer_identifier=identifier,
            transfer_amount=amount,
            transfer_secret=secret,
            token_network_identifier=token_network_identifier,
            target_address=target,
        )

        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return payment_status

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        """ Process a received locked transfer as a mediator. """
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        """ Process a received locked transfer as its target. """
        self.start_health_check_for(transfer.initiator)
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)

    def maybe_upgrade_db(self):
        """ Run the database UpgradeManager against `database_path`. """
        manager = UpgradeManager(db_filename=self.database_path)
        manager.run()
def __init__(self, url=None, ie_info=None, *args, **kwargs):
    """Initialise a YoutubeDLInput.

    :param url: stored on the instance as ``_url``.
    :param ie_info: stored on the instance as ``_ie_info`` (presumably
        pre-extracted youtube-dl info; confirm against callers).

    Any further positional and keyword arguments are forwarded to the parent
    initializer, whose first positional argument is always passed as None.
    """
    parent = super(YoutubeDLInput, self)
    parent.__init__(None, *args, **kwargs)
    self._url, self._ie_info = url, ie_info
    # Info starts out unset; the Semaphore below presumably guards it
    # (NOTE(review): its users are outside this view).
    self._info = None
    self._info_lock = Semaphore()