Code example #1
File: gevent.py  Project: CGenie/raven-python
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100, *args, **kwargs):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')

        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url, *args, **kwargs)

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(
            super(GeventedHTTPTransport, self).send, data, headers
        ).link(lambda x: self._done(x, success_cb, failure_cb))

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
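
The snippet above caps the number of in-flight requests with a Semaphore and frees a slot from the link() callback once the greenlet finishes. A minimal standalone sketch of the same pattern, assuming only gevent itself (names such as do_work, spawn_bounded and MAX_OUTSTANDING are illustrative, not part of raven-python):

import gevent
from gevent.lock import Semaphore

MAX_OUTSTANDING = 10
_slots = Semaphore(MAX_OUTSTANDING)

def do_work(i):
    gevent.sleep(0.01)
    return i * 2

def spawn_bounded(i):
    # blocks once MAX_OUTSTANDING greenlets are already in flight
    _slots.acquire()
    g = gevent.spawn(do_work, i)
    # release the slot when the greenlet finishes, whatever the outcome
    g.link(lambda greenlet: _slots.release())
    return g

if __name__ == '__main__':
    gevent.joinall([spawn_bounded(i) for i in range(50)])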
Code example #2
File: pool.py  Project: dsuch/gevent
    def __init__(self, size=None, greenlet_class=None):
        """
        Create a new pool.

        A pool is like a group, but the maximum number of members
        is governed by the *size* parameter.

        :keyword int size: If given, this non-negative integer is the
            maximum count of active greenlets that will be allowed in
            this pool. A few values have special significance:

            * ``None`` (the default) places no limit on the number of
              greenlets. This is useful when you need to track, but not limit,
              greenlets, as with :class:`gevent.pywsgi.WSGIServer`. A :class:`Group`
              may be a more efficient way to achieve the same effect.
            * ``0`` creates a pool that can never have any active greenlets. Attempting
              to spawn in this pool will block forever. This is only useful
              if an application uses :meth:`wait_available` with a timeout and checks
              :meth:`free_count` before attempting to spawn.
        """
        if size is not None and size < 0:
            raise ValueError('size must not be negative: %r' % (size, ))
        Group.__init__(self)
        self.size = size
        if greenlet_class is not None:
            self.greenlet_class = greenlet_class
        if size is None:
            self._semaphore = DummySemaphore()
        else:
            self._semaphore = Semaphore(size)
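
For context, a short usage sketch of the size semantics documented in the docstring above, using gevent's public gevent.pool.Pool (the values and function names are illustrative):

import gevent
from gevent.pool import Pool

pool = Pool(size=3)  # at most 3 greenlets active at any time

def fetch(n):
    gevent.sleep(0.1)
    return n * n

# pool.spawn() blocks once 3 greenlets are running, so the work proceeds three at a time
jobs = [pool.spawn(fetch, n) for n in range(9)]
pool.join()
print([job.value for job in jobs])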
Code example #3
File: job.py  Project: razzfazz/freenas
class JobSharedLock(object):
    """
    Shared lock for jobs.
    Each job method can specify a lock which will be shared
    among all calls for that job and only one job can run at a time
    for this lock.
    """

    def __init__(self, queue, name):
        self.queue = queue
        self.name = name
        self.jobs = []
        self.semaphore = Semaphore()

    def add_job(self, job):
        self.jobs.append(job)

    def get_jobs(self):
        return self.jobs

    def remove_job(self, job):
        self.jobs.remove(job)

    def locked(self):
        return self.semaphore.locked()

    def acquire(self):
        return self.semaphore.acquire()

    def release(self):
        return self.semaphore.release()
Code example #4
File: base.py  Project: Aaron1011/domorereps-android
class GeventedHTTPTransport(HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]

    def send(self, data, headers):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(super(GeventedHTTPTransport, self).send, data, headers).link(self._done, self)

    def _done(self, *args):
        self._lock.release()
Code example #5
File: base.py  Project: JTCunning/raven-python
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):

    scheme = ["gevent+http", "gevent+https"]

    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        if not has_gevent:
            raise ImportError("GeventedHTTPTransport requires gevent.")
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split("+", 1)[-1]

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(super(GeventedHTTPTransport, self).send, data, headers).link(
            lambda x: self._done(x, success_cb, failure_cb)
        )

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.value)
Code example #6
File: test__iwait.py  Project: gevent/gevent
    def test_iwait_nogarbage(self):
        sem1 = Semaphore()
        sem2 = Semaphore()
        let = gevent.spawn(sem1.release)
        with gevent.iwait((sem1, sem2)) as iterator:
            self.assertEqual(sem1, next(iterator))
            self.assertEqual(sem2.linkcount(), 1)

        self.assertEqual(sem2.linkcount(), 0)
        let.get()
Code example #7
File: pool.py  Project: Apolot/gevent
 def __init__(self, size=None, greenlet_class=None):
     if size is not None and size < 0:
         raise ValueError('size must not be negative: %r' % (size, ))
     Group.__init__(self)
     self.size = size
     if greenlet_class is not None:
         self.greenlet_class = greenlet_class
     if size is None:
         self._semaphore = DummySemaphore()
     else:
         self._semaphore = Semaphore(size)
Code example #8
File: test__semaphore.py  Project: carriercomm/gevent
 def test_release_twice(self):
     s = Semaphore()
     result = []
     s.rawlink(lambda s: result.append('a'))
     s.release()
     s.rawlink(lambda s: result.append('b'))
     s.release()
     gevent.sleep(0.001)
     self.assertEqual(result, ['a', 'b'])
Code example #9
File: test__semaphore.py  Project: gevent/gevent
 def test_release_twice(self):
     s = Semaphore()
     result = []
     s.rawlink(lambda s: result.append('a'))
     s.release()
     s.rawlink(lambda s: result.append('b'))
     s.release()
     gevent.sleep(0.001)
     # The order, though, is not guaranteed.
     self.assertEqual(sorted(result), ['a', 'b'])
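
For orientation, a hedged sketch of the rawlink() behaviour these two tests exercise: the registered callback is run from the hub once the semaphore becomes acquirable, and it receives the semaphore itself as its only argument (variable names are illustrative):

import gevent
from gevent.lock import Semaphore

s = Semaphore(0)          # starts locked
fired = []
s.rawlink(lambda sem: fired.append(sem is s))
s.release()               # makes the semaphore acquirable; the callback is queued
gevent.sleep(0.001)       # give the hub a chance to run the queued callback
print(fired)              # expected: [True]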
Code example #10
File: pool.py  Project: uschen/gevent3
class Pool(Group):
    def __init__(self, size=None, greenlet_class=None):
        if size is not None and size < 0:
            raise ValueError("size must not be negative: %r" % (size,))
        Group.__init__(self)
        self.size = size
        if greenlet_class is not None:
            self.greenlet_class = greenlet_class
        if size is None:
            self._semaphore = DummySemaphore()
        else:
            self._semaphore = Semaphore(size)

    def wait_available(self):
        self._semaphore.wait()

    def full(self):
        return self.free_count() <= 0

    def free_count(self):
        if self.size is None:
            return 1
        return max(0, self.size - len(self))

    def add(self, greenlet):
        self._semaphore.acquire()
        try:
            Group.add(self, greenlet)
        except:
            self._semaphore.release()
            raise

    def _discard(self, greenlet):
        Group._discard(self, greenlet)
        self._semaphore.release()
Code example #11
File: base.py  Project: fr4c74l/diggems
class GeventConnectionPool(psycopg2.pool.AbstractConnectionPool):
    def __init__(self, minconn, maxconn, *args, **kwargs):
        self.semaphore = Semaphore(maxconn)
        psycopg2.pool.AbstractConnectionPool.__init__(self, minconn, maxconn, *args, **kwargs)

    def getconn(self, *args, **kwargs):
        self.semaphore.acquire()
        return self._getconn(*args, **kwargs)

    def putconn(self, *args, **kwargs):
        self._putconn(*args, **kwargs)
        self.semaphore.release()

    # Not sure what to do about this one...
    closeall = psycopg2.pool.AbstractConnectionPool._closeall
Code example #12
File: gevent.py  Project: CGenie/raven-python
    def __init__(self, parsed_url, maximum_outstanding_requests=100, *args, **kwargs):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')

        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url, *args, **kwargs)
Code example #13
 def __init__(self, logger, zkservers, svc_name, inst,
             watchers={}, zpostfix="", freq=10):
     gevent.Greenlet.__init__(self)
     self._svc_name = svc_name
     self._inst = inst
     self._zk_server = zkservers
     # initialize logging and other stuff
     if logger is None:
         logging.basicConfig()
         self._logger = logging
     else:
         self._logger = logger
     self._conn_state = None
     self._sandesh_connection_info_update(status='INIT', message='')
     self._zkservers = zkservers
     self._zk = None
     self._pubinfo = None
     self._publock = Semaphore()
     self._watchers = watchers
     self._wchildren = {}
     self._pendingcb = set()
     self._zpostfix = zpostfix
     self._basepath = "/analytics-discovery-" + self._zpostfix
     self._reconnect = None
     self._freq = freq
Code example #14
    def __init__(self, url):
        self.sock_file = get_data(url)
        gevent.spawn(self.start_reading_bytes)
        self.file_buffer = array.array('c')
        self.file_buffer_ptr = 0
#         gevent.spawn(self.start_reading_dummy_bytes)
        self.evt = Semaphore(value=0)
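
Semaphore(value=0), as used for self.evt above, starts out locked and therefore doubles as a one-shot signal between greenlets. A minimal sketch of that idiom, assuming only gevent (the producer/consumer names are illustrative):

import gevent
from gevent.lock import Semaphore

ready = Semaphore(value=0)   # acquire() blocks until someone calls release()
box = []

def producer():
    gevent.sleep(0.05)
    box.append('payload')
    ready.release()          # signal the waiting consumer

def consumer():
    ready.acquire()          # blocks until producer() has released
    print(box[0])

gevent.joinall([gevent.spawn(producer), gevent.spawn(consumer)])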
Code example #15
    def __init__(self, uplink=None, downlink=None):
        self.log = logging.getLogger("{module}.{name}".format(
            module=self.__class__.__module__, name=self.__class__.__name__))
        
        self.downlink = downlink
        self.uplink = uplink
        self.recv_callback = None

        self.context = zmq.Context()
        self.poller = zmq.Poller()

        self.ul_socket = self.context.socket(zmq.SUB) # one SUB socket for uplink communication over topics
        if sys.version_info.major >= 3:
            self.ul_socket.setsockopt_string(zmq.SUBSCRIBE,  "NEW_NODE")
            self.ul_socket.setsockopt_string(zmq.SUBSCRIBE,  "NODE_EXIT")
        else:
            self.ul_socket.setsockopt(zmq.SUBSCRIBE,  "NEW_NODE")
            self.ul_socket.setsockopt(zmq.SUBSCRIBE,  "NODE_EXIT")

        self.downlinkSocketLock = Semaphore(value=1)
        self.dl_socket = self.context.socket(zmq.PUB) # one PUB socket for downlink communication over topics

        #register UL socket in poller
        self.poller.register(self.ul_socket, zmq.POLLIN)

        self.importedPbClasses = {}
Code example #16
File: dynamic.py  Project: Tefx/Brick
 def __init__(self, provider, n, **kwargs):
     super(LimitEngine, self).__init__(provider, **kwargs)
     self.conf = self.provider.configurations()[0]
     self.n = n
     self.services = []
     self.num_unscheduled = -1
     self.le_lock = Semaphore()
Code example #17
    def __init__(self, rabbit_server, rabbit_port,
                 rabbit_user, rabbit_password,
                 notification_level, ironic_notif_mgr_obj, **kwargs):
        self._rabbit_port = rabbit_port
        self._rabbit_user = rabbit_user
        self._rabbit_password = rabbit_password
        self._rabbit_hosts = self._parse_rabbit_hosts(rabbit_server)
        self._rabbit_ip = self._rabbit_hosts[0]["host"]
        self._notification_level = notification_level
        self._ironic_notification_manager = ironic_notif_mgr_obj
        self._conn_lock = Semaphore()

        # Register a handler for SIGTERM so that we can release the lock
        # Without it, it can take several minutes before new master is elected
        # If any app using this wants to register their own sigterm handler,
        # then we will have to modify this function to perhaps take an argument
        # gevent.signal(signal.SIGTERM, self.sigterm_handler)

        self._url = "amqp://%s:%s@%s:%s/" % (self._rabbit_user,
                                             self._rabbit_password,
                                             self._rabbit_ip,
                                             self._rabbit_port)
        msg = "Initializing RabbitMQ connection, urls %s" % self._url
        # self._conn_state = ConnectionStatus.INIT
        self._conn = kombu.Connection(self._url)
        self._exchange = self._set_up_exchange()
        self._queues = []
        self._queues = self._set_up_queues(self._notification_level)
        if not self._queues:
            exit()
Code example #18
    def __init__(
        self,
        rabbit_ip,
        rabbit_port,
        rabbit_user,
        rabbit_password,
        rabbit_vhost,
        rabbit_ha_mode,
        q_name,
        subscribe_cb,
        logger,
    ):
        self._rabbit_ip = rabbit_ip
        self._rabbit_port = rabbit_port
        self._rabbit_user = rabbit_user
        self._rabbit_password = rabbit_password
        self._rabbit_vhost = rabbit_vhost
        self._subscribe_cb = subscribe_cb
        self._logger = logger
        self._publish_queue = Queue()
        self._conn_lock = Semaphore()

        self.obj_upd_exchange = kombu.Exchange("vnc_config.object-update", "fanout", durable=False)

        # Register a handler for SIGTERM so that we can release the lock
        # Without it, it can take several minutes before new master is elected
        # If any app using this wants to register their own sigterm handler,
        # then we will have to modify this function to perhaps take an argument
        gevent.signal(signal.SIGTERM, self.sigterm_handler)
Code example #19
File: worker.py  Project: Tefx/Brick
class Puppet(object):
    def __init__(self):
        self.results = {}
        self.status = ("Idle", None)
        self.worker = None
        self.operator = None
        self.lock = Semaphore()

    def receive_info(self):
        while True:
            (info_id, info) = self.task_queue.get()
            if info_id > 0:
                self.results[info_id].set(info)
            elif info_id == -1:
                self.status = info

    def submit_task(self, tid, task_info):
        self.results[tid] = AsyncResult()
        self.lock.acquire()
        self.task_queue.put((tid, task_info))
        self.lock.release()

    def fetch_result(self, tid):
        res = self.results[tid].get()
        del self.results[tid]
        return res

    def hire_worker(self):
        tq_worker, self.task_queue = gipc.pipe(duplex=True)
        self.worker = gipc.start_process(target=process_worker, args=(tq_worker,))
        self.operator = gevent.spawn(self.receive_info)

    def fire_worker(self):
        self.worker.terminate()
        self.operator.kill()

    def get_attr(self, attr):
        if attr == "cpu":
            return psutil.cpu_percent(interval=None)
        elif attr == "memory":
            return psutil.virtual_memory().percent
        else:
            return getattr(self, attr, "No such attribute")

    def current_tasks(self):
        return self.results.keys()
Code example #20
File: net_gevent.py  Project: thingbreaker/productAPI
    def __init__(self, parent, io_loop=None):
        self._parent = parent
        self._closing = False
        self._user_queries = { }
        self._cursor_cache = { }

        self._write_mutex = Semaphore()
        self._socket = None
Code example #21
File: base.py  Project: Aaron1011/domorereps-android
    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]
Code example #22
def test_renamed_label_refresh(db, default_account, thread, message,
                               imapuid, folder, mock_imapclient, monkeypatch):
    # Check that imapuids see their labels refreshed after running
    # the LabelRenameHandler.
    msg_uid = imapuid.msg_uid
    uid_dict = {msg_uid: GmailFlags((), ('stale label',), ('23',))}

    update_metadata(default_account.id, folder.id, folder.canonical_name,
                    uid_dict, db.session)

    new_flags = {msg_uid: {'FLAGS': ('\\Seen',), 'X-GM-LABELS': ('new label',),
                           'MODSEQ': ('23',)}}
    mock_imapclient._data['[Gmail]/All mail'] = new_flags

    mock_imapclient.add_folder_data(folder.name, new_flags)

    monkeypatch.setattr(MockIMAPClient, 'search',
                        lambda x, y: [msg_uid])

    semaphore = Semaphore(value=1)

    rename_handler = LabelRenameHandler(default_account.id,
                                        default_account.namespace.id,
                                        'new label', semaphore)

    # Acquire the semaphore to check that LabelRenameHandlers block if
    # the semaphore is in-use.
    semaphore.acquire()
    rename_handler.start()

    # Wait 10 secs and check that the data hasn't changed.
    gevent.sleep(10)

    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'stale label'
    semaphore.release()
    rename_handler.join()

    db.session.refresh(imapuid)
    # Now check that the label got updated.
    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'new label'
Code example #23
File: server.py  Project: dbrehmer/Knowself
class DatagramServer(BaseServer):
    """A UDP server"""

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        # The raw (non-gevent) socket, if possible
        self._socket = None
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        if not hasattr(self, 'socket'):
            # FIXME: clean up the socket lifetime
            # pylint:disable=attribute-defined-outside-init
            self.socket = self.get_listener(self.address, self.family)
            self.address = self.socket.getsockname()
        self._socket = self.socket
        try:
            self._socket = self._socket._sock
        except AttributeError:
            pass

    @classmethod
    def get_listener(cls, address, family=None):
        return _udp_socket(address, reuse_addr=cls.reuse_addr, family=family)

    def do_read(self):
        try:
            data, address = self._socket.recvfrom(8192)
        except _socket.error as err:
            if err.args[0] == EWOULDBLOCK:
                return
            raise
        return data, address

    def sendto(self, *args):
        self._writelock.acquire()
        try:
            self.socket.sendto(*args)
        finally:
            self._writelock.release()
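
A Semaphore() constructed with no arguments has an initial value of 1 and so behaves as a plain mutex, which is how the write lock around sendto() above works. A minimal standalone sketch of that locking idiom (names are illustrative):

import gevent
from gevent.lock import Semaphore

_writelock = Semaphore()   # initial value 1: acts as a mutex
shared = []

def writer(n):
    _writelock.acquire()
    try:
        # critical section: only one greenlet appends at a time
        shared.append(n)
        gevent.sleep(0)
    finally:
        _writelock.release()

gevent.joinall([gevent.spawn(writer, i) for i in range(5)])
print(shared)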
Code example #24
File: __init__.py  Project: tf198/wishbone
    def __init__(self, actor_config, location="", rules={}):
        Actor.__init__(self, actor_config)

        self.pool.createQueue("inbox")
        self.pool.createQueue("nomatch")
        self.registerConsumer(self.consume, "inbox")

        self.__active_rules = {}
        self.match = MatchRules()
        self.rule_lock = Semaphore()
Code example #25
File: base.py  Project: Tefx/Brick
 def __init__(self, s_id, conf):
     self.s_id = s_id
     self.conf = conf
     self.start_time = None
     self.finish_time = None
     self.puppet = None
     self.lock = Semaphore()
     self.queue = set()
     self._status = "Unknown"
     self.started = False
Code example #26
File: server.py  Project: dustyneuron/gevent
class DatagramServer(BaseServer):
    """A UDP server"""

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        if not hasattr(self, 'socket'):
            self.socket = self.get_listener(self.address, self.family)
            self.address = self.socket.getsockname()
        self._socket = self.socket
        try:
            self._socket = self._socket._sock
        except AttributeError:
            pass

    @classmethod
    def get_listener(self, address, family=None):
        return _udp_socket(address, reuse_addr=self.reuse_addr, family=family)

    def do_read(self):
        try:
            data, address = self._socket.recvfrom(8192)
        except _socket.error:
            err = sys.exc_info()[1]
            if err.args[0] == EWOULDBLOCK:
                return
            raise
        return data, address

    def sendto(self, *args):
        self._writelock.acquire()
        try:
            self.socket.sendto(*args)
        finally:
            self._writelock.release()
Code example #27
File: dynamic.py  Project: Tefx/Brick
class LimitEngine(EngineBase):
    def __init__(self, provider, n, **kwargs):
        super(LimitEngine, self).__init__(provider, **kwargs)
        self.conf = self.provider.configurations()[0]
        self.n = n
        self.services = []
        self.num_unscheduled = -1
        self.le_lock = Semaphore()

    def before_eval(self):
        self.num_unscheduled = len(self.dag)

    def after_eval(self):
        for s in self.services:
            self.provider.stop_service(s)
        self.services = []

    def which_service(self, task):
        self.num_unscheduled -= 1
        for s in self.services:
            if len(s.tasks) == 0:
                return s
        if len(self.services) < self.n:
            s = self.provider.start_service(len(self.services) + 1, self.conf)
            self.services.append(s)
        else:
            s = min(self.services, key=lambda x: len(x.tasks))
        return s

    def after_task(self, task, service):
        self.le_lock.acquire()
        if self.num_unscheduled == 0 and len(service.tasks) == 0:
            self.provider.stop_service(service)
            for s in self.services:
                if len(s.tasks) == 0:
                    self.provider.stop_service(s)
        self.le_lock.release()

    def current_services(self):
        return self.services
Code example #28
File: test.py  Project: sublee/lets
def _test_atomic():
    # NOTE: Nested context by comma is not available in Python 2.6.
    # o -- No gevent.
    with lets.atomic():
        1 + 2 + 3
    # x -- gevent.sleep()
    with pytest.raises(AssertionError):
        with lets.atomic():
            gevent.sleep(0.1)
    # x -- gevent.sleep() with 0 seconds.
    with pytest.raises(AssertionError):
        with lets.atomic():
            gevent.sleep(0)
    # o -- Greenlet.spawn()
    with lets.atomic():
        gevent.spawn(gevent.sleep, 0.1)
    # x -- Greenlet.join()
    with pytest.raises(AssertionError):
        with lets.atomic():
            g = gevent.spawn(gevent.sleep, 0.1)
            g.join()
    # x -- Greenlet.get()
    with pytest.raises(AssertionError):
        with lets.atomic():
            g = gevent.spawn(gevent.sleep, 0.1)
            g.get()
    # x -- gevent.joinall()
    with pytest.raises(AssertionError):
        with lets.atomic():
            g = gevent.spawn(gevent.sleep, 0.1)
            gevent.joinall([g])
    # o -- Event.set(), AsyncResult.set()
    with lets.atomic():
        Event().set()
        AsyncResult().set()
    # x -- Event.wait()
    with pytest.raises(AssertionError):
        with lets.atomic():
            Event().wait()
    # x -- Event.wait()
    with pytest.raises(AssertionError):
        with lets.atomic():
            AsyncResult().wait()
    # x -- Channel.put()
    with pytest.raises(AssertionError):
        with lets.atomic():
            ch = Channel()
            ch.put(123)
    # o -- First Semaphore.acquire()
    with lets.atomic():
        lock = Semaphore()
        lock.acquire()
    # x -- Second Semaphore.acquire()
    with pytest.raises(AssertionError):
        with lets.atomic():
            lock = Semaphore()
            lock.acquire()
            lock.acquire()
    # Back to normal.
    gevent.sleep(1)
Code example #29
File: models.py  Project: MoroGasper/client
    def __init__(self, **kwargs):
        self.account = self # needed for InputFunctions.solve_* functions

        self.multi_account = False

        self.lock = Semaphore()
        self.check_pool = VariableSizePool(size=self.max_check_tasks)
        self.download_pool = VariableSizePool(size=self.max_download_tasks)
        self.search_pool = VariableSizePool(size=10)
        self.reset()

        for k, v in kwargs.iteritems():
            setattr(self, k, v)
Code example #30
File: gevent.py  Project: lxyu/raven-python
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100,
                 oneway=False):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)
        self._oneway = oneway == 'true'

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.
        """
        if not self._oneway:
            self._lock.acquire()
            return gevent.spawn(
                super(GeventedHTTPTransport, self).send, data, headers
            ).link(lambda x: self._done(x, success_cb, failure_cb))
        else:
            req = urllib2.Request(self._url, headers=headers)
            return urlopen(url=req,
                           data=data,
                           timeout=self.timeout,
                           verify_ssl=self.verify_ssl,
                           ca_certs=self.ca_certs)

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
Code example #31
 def __init__(self, *args, **kwargs):
     BaseServer.__init__(self, *args, **kwargs)
     from gevent.lock import Semaphore
     self._writelock = Semaphore()
Code example #32
File: rpc_server.py  Project: zhouqiang-cl/wloki
 def __init__(self, sck, server):
     self._sck = sck
     self._rlock = Semaphore()
     self._wlock = Semaphore()
     self._server = server
     self._living_controllers = {}
Code example #33
class AnalyticsDiscovery(gevent.Greenlet):
    def _sandesh_connection_info_update(self, status, message):

        new_conn_state = getattr(ConnectionStatus, status)
        ConnectionState.update(conn_type=ConnectionType.ZOOKEEPER,
                               name=self._svc_name,
                               status=new_conn_state,
                               server_addrs=self._zk_server.split(','),
                               message=message)

        if (self._conn_state and self._conn_state != ConnectionStatus.DOWN
                and new_conn_state == ConnectionStatus.DOWN):
            msg = 'Connection to Zookeeper down: %s' % (message)
            self._logger.error(msg)
        if (self._conn_state and self._conn_state != new_conn_state
                and new_conn_state == ConnectionStatus.UP):
            msg = 'Connection to Zookeeper ESTABLISHED'
            self._logger.error(msg)

        self._conn_state = new_conn_state

    # end _sandesh_connection_info_update

    def _zk_listen(self, state):
        self._logger.error("Analytics Discovery listen %s" % str(state))
        if state == KazooState.CONNECTED:
            self._sandesh_connection_info_update(
                status='UP', message='Connection to Zookeeper re-established')
            self._logger.error("Analytics Discovery to publish %s" %
                               str(self._pubinfo))
            self._reconnect = True
        elif state == KazooState.LOST:
            self._logger.error("Analytics Discovery connection LOST")
            # Lost the session with ZooKeeper Server
            # The best option we have is to exit the process and restart all
            # over again
            self._sandesh_connection_info_update(
                status='DOWN', message='Connection to Zookeeper lost')
            os._exit(2)
        elif state == KazooState.SUSPENDED:
            self._logger.error("Analytics Discovery connection SUSPENDED")
            # Update connection info
            self._sandesh_connection_info_update(
                status='INIT',
                message='Connection to zookeeper lost. Retrying')

    def _zk_datawatch(self, watcher, child, data, stat, event="unknown"):
        self._logger.error(\
                "Analytics Discovery %s ChildData : child %s, data %s, event %s" % \
                (watcher, child, data, event))
        if data:
            data_dict = json.loads(data)
            self._wchildren[watcher][child] = OrderedDict(
                sorted(data_dict.items()))
        else:
            if child in self._wchildren[watcher]:
                del self._wchildren[watcher][child]
        if self._data_watchers[watcher]:
            self._pendingcb.add(watcher)

    def _zk_watcher(self, watcher, children):
        self._logger.error("Analytics Discovery Watcher %s Children %s" %
                           (watcher, children))
        self._reconnect = True

    def __init__(self,
                 logger,
                 zkservers,
                 svc_name,
                 inst,
                 data_watchers={},
                 child_watchers={},
                 zpostfix="",
                 freq=10):
        gevent.Greenlet.__init__(self)
        self._svc_name = svc_name
        self._inst = inst
        self._zk_server = zkservers
        # initialize logging and other stuff
        if logger is None:
            logging.basicConfig()
            self._logger = logging
        else:
            self._logger = logger
        self._conn_state = None
        self._sandesh_connection_info_update(
            status='INIT', message='Connection to Zookeeper initialized')
        self._zkservers = zkservers
        self._zk = None
        self._pubinfo = None
        self._publock = Semaphore()
        self._data_watchers = data_watchers
        self._child_watchers = child_watchers
        self._wchildren = {}
        self._pendingcb = set()
        self._zpostfix = zpostfix
        self._basepath = "/analytics-discovery-" + self._zpostfix
        self._reconnect = None
        self._freq = freq

    def publish(self, pubinfo):

        # This function can be called concurrently by the main AlarmDiscovery
        # processing loop as well as by clients.
        # It is NOT re-entrant
        self._publock.acquire()

        self._pubinfo = pubinfo
        if self._conn_state == ConnectionStatus.UP:
            try:
                self._logger.error("ensure %s" %
                                   (self._basepath + "/" + self._svc_name))
                self._logger.error("zk state %s (%s)" %
                                   (self._zk.state, self._zk.client_state))
                self._zk.ensure_path(self._basepath + "/" + self._svc_name)
                self._logger.error("check for %s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst))
                if pubinfo is not None:
                    if self._zk.exists("%s/%s/%s" % \
                            (self._basepath, self._svc_name, self._inst)):
                        self._zk.set("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst),
                                self._pubinfo)
                    else:
                        self._zk.create("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst),
                                self._pubinfo, ephemeral=True)
                else:
                    if self._zk.exists("%s/%s/%s" % \
                            (self._basepath, self._svc_name, self._inst)):
                        self._logger.error("withdrawing published info!")
                        self._zk.delete("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst))

            except Exception as ex:
                template = "Exception {0} in AnalyticsDiscovery publish. Args:\n{1!r}"
                messag = template.format(type(ex).__name__, ex.args)
                self._logger.error("%s : traceback %s for %s info %s" % \
                        (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
                self._sandesh_connection_info_update(
                    status='DOWN',
                    message='Reconnect to Zookeeper to handle exception')
                self._reconnect = True
        else:
            self._logger.error("Analytics Discovery cannot publish while down")
        self._publock.release()

    def _run(self):
        while True:
            self._logger.error("Analytics Discovery zk start")
            self._zk = KazooClient(hosts=self._zkservers)
            self._zk.add_listener(self._zk_listen)
            try:
                self._zk.start()
                while self._conn_state != ConnectionStatus.UP:
                    gevent.sleep(1)
                break
            except Exception as e:
                # Update connection info
                self._sandesh_connection_info_update(status='DOWN',
                                                     message=str(e))
                self._zk.remove_listener(self._zk_listen)
                try:
                    self._zk.stop()
                    self._zk.close()
                except Exception as ex:
                    template = "Exception {0} in AnalyticsDiscovery zk stop/close. Args:\n{1!r}"
                    messag = template.format(type(ex).__name__, ex.args)
                    self._logger.error("%s : traceback %s for %s" % \
                        (messag, traceback.format_exc(), self._svc_name))
                finally:
                    self._zk = None
                gevent.sleep(1)

        try:
            # Update connection info
            self._sandesh_connection_info_update(
                status='UP', message='Connection to Zookeeper established')
            self._reconnect = False
            # Done connecting to ZooKeeper

            for wk in self._data_watchers.keys():
                self._zk.ensure_path(self._basepath + "/" + wk)
                self._wchildren[wk] = {}
                self._zk.ChildrenWatch(self._basepath + "/" + wk,
                                       partial(self._zk_watcher, wk))
            for wk in self._child_watchers.keys():
                self._zk.ensure_path(self._basepath + "/" + wk)
                self._zk.ChildrenWatch(self._basepath + "/" + wk,
                                       self._child_watchers[wk])
            # Trigger the initial publish
            self._reconnect = True

            while True:
                try:
                    if not self._reconnect:
                        pending_list = list(self._pendingcb)
                        self._pendingcb = set()
                        for wk in pending_list:
                            if self._data_watchers[wk]:
                                self._data_watchers[wk](\
                                        sorted(self._wchildren[wk].values()))

                    # If a reconnect happens during processing, don't lose it
                    while self._reconnect:
                        self._logger.error("Analytics Discovery %s reconnect" \
                                % self._svc_name)
                        self._reconnect = False
                        self._pendingcb = set()
                        self.publish(self._pubinfo)

                        for wk in self._data_watchers.keys():
                            self._zk.ensure_path(self._basepath + "/" + wk)
                            children = self._zk.get_children(self._basepath +
                                                             "/" + wk)

                            old_children = set(self._wchildren[wk].keys())
                            new_children = set(children)

                            # Remove contents for the children who are gone
                            # (DO NOT remove the watch)
                            for elem in old_children - new_children:
                                del self._wchildren[wk][elem]

                            # Overwrite existing children, or create new ones
                            for elem in new_children:
                                # Create a watch for new children
                                if elem not in self._wchildren[wk]:
                                    self._zk.DataWatch(self._basepath + "/" + \
                                            wk + "/" + elem,
                                            partial(self._zk_datawatch, wk, elem))

                                data_str, _ = self._zk.get(\
                                        self._basepath + "/" + wk + "/" + elem)
                                data_dict = json.loads(data_str)
                                self._wchildren[wk][elem] = \
                                        OrderedDict(sorted(data_dict.items()))

                                self._logger.error(\
                                    "Analytics Discovery %s ChildData : child %s, data %s, event %s" % \
                                    (wk, elem, self._wchildren[wk][elem], "GET"))
                            if self._data_watchers[wk]:
                                self._data_watchers[wk](sorted(
                                    self._wchildren[wk].values()))

                    gevent.sleep(self._freq)
                except gevent.GreenletExit:
                    self._logger.error("Exiting AnalyticsDiscovery for %s" % \
                            self._svc_name)
                    self._zk.remove_listener(self._zk_listen)
                    gevent.sleep(1)
                    try:
                        self._zk.stop()
                    except:
                        self._logger.error("Stopping kazooclient failed")
                    else:
                        self._logger.error("Stopping kazooclient successful")
                    try:
                        self._zk.close()
                    except:
                        self._logger.error("Closing kazooclient failed")
                    else:
                        self._logger.error("Closing kazooclient successful")
                    break

                except Exception as ex:
                    template = "Exception {0} in AnalyticsDiscovery reconnect. Args:\n{1!r}"
                    messag = template.format(type(ex).__name__, ex.args)
                    self._logger.error("%s : traceback %s for %s info %s" % \
                        (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
                    self._reconnect = True

        except Exception as ex:
            template = "Exception {0} in AnalyticsDiscovery run. Args:\n{1!r}"
            messag = template.format(type(ex).__name__, ex.args)
            self._logger.error("%s : traceback %s for %s info %s" % \
                    (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
            raise SystemExit
Code example #34
class ScpSever():
    def __init__(self,conn):
        self.conn = conn
        self.closed = False
        self.connerr = None
        self.conn_mutex = RLock()
        self.conn_cond = Semaphore(0)
        
    def read(self,size):
        conn, err = self.acquire_conn()
        if err: #conn is closed
            return '',err
        data,err = conn.read(size)
        if err:
            #freeze, waiting for reuse
            conn.freeze()
            self.connerr = err
        return data, None


    def write(self,data):
        conn, err = self.acquire_conn()
        if err: #conn is closed
            return '',err
        err = self.conn.write(data)
        if err:
            #freeze, waiting for reuse
            conn.freeze()
            self.connerr = err
        return None
    
    @with_goto
    def close(self):
        self.conn_mutex.acquire()
        if self.closed:
            goto .end
        self.conn.close()
        self.closed = True
        self.connerr = error
        label .end
        self.conn_cond.release()
        self.conn_mutex.release()
        return self.connerr

    # timeout countdown
    def _star_wait(self):
        reuse_timeout = int(config['listen']['reuse_time'])
        self.time_task = Timer(reuse_timeout,self.close)
        self.time_task.start()

    def _stop_wait(self):
        self.time_task.cancel()

    def _cond_wait(self):
        self.conn_mutex.release()
        self.conn_cond.acquire()
        self.conn_mutex.acquire()

    def acquire_conn(self):
        self.conn_mutex.acquire()
        conn = None
        connerr = None
        while True:
            if self.closed:
                connerr = self.connerr
                break
            elif self.connerr:
                self._star_wait()
                self._cond_wait()
                self._stop_wait()
            else:
                conn = self.conn
                break
        self.conn_mutex.release()
        return conn, connerr

    @with_goto
    def replace_conn(self, conn):
        self.conn_mutex.acquire()
        ret = False
        if self.closed:
            goto .end
        #close old conn
        self.conn.close()
        #set new status
        self.conn = conn
        self.connerr = None
        ret = True
        label .end
        self.conn_cond.release()
        self.conn_mutex.release()
        return ret
Code example #35
File: rotkehlchen.py  Project: resslerruntime/rotki
class Rotkehlchen():
    def __init__(self, args: argparse.Namespace) -> None:
        """Initialize the Rotkehlchen object

        May Raise:
        - SystemPermissionError if the given data directory's permissions
        are not correct.
        """
        self.lock = Semaphore()
        self.lock.acquire()

        # Can also be None after unlock if premium credentials did not
        # authenticate or premium server temporarily offline
        self.premium: Optional[Premium] = None
        self.user_is_logged_in: bool = False
        configure_logging(args)

        self.sleep_secs = args.sleep_secs
        if args.data_dir is None:
            self.data_dir = default_data_directory()
        else:
            self.data_dir = Path(args.data_dir)

        if not os.access(self.data_dir, os.W_OK | os.R_OK):
            raise SystemPermissionError(
                f'The given data directory {self.data_dir} is not readable or writable',
            )
        self.args = args
        self.msg_aggregator = MessagesAggregator()
        self.greenlet_manager = GreenletManager(
            msg_aggregator=self.msg_aggregator)
        self.exchange_manager = ExchangeManager(
            msg_aggregator=self.msg_aggregator)
        # Initialize the AssetResolver singleton
        AssetResolver(data_directory=self.data_dir)
        self.data = DataHandler(self.data_dir, self.msg_aggregator)
        self.cryptocompare = Cryptocompare(data_directory=self.data_dir,
                                           database=None)
        self.coingecko = Coingecko()
        self.icon_manager = IconManager(data_dir=self.data_dir,
                                        coingecko=self.coingecko)
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='periodically_query_icons_until_all_cached',
            method=self.icon_manager.periodically_query_icons_until_all_cached,
            batch_size=ICONS_BATCH_SIZE,
            sleep_time_secs=ICONS_QUERY_SLEEP,
        )
        # Initialize the Inquirer singleton
        Inquirer(
            data_dir=self.data_dir,
            cryptocompare=self.cryptocompare,
            coingecko=self.coingecko,
        )
        # Keeps how many trades we have found per location. Used for free user limiting
        self.actions_per_location: Dict[str, Dict[Location, int]] = {
            'trade': defaultdict(int),
            'asset_movement': defaultdict(int),
        }

        self.lock.release()
        self.shutdown_event = gevent.event.Event()

    def reset_after_failed_account_creation_or_login(self) -> None:
        """If the account creation or login failed make sure that the Rotki instance is clear

        Tricky instances are when after either failed premium credentials or user refusal
        to sync premium databases we relogged in.
        """
        self.cryptocompare.db = None

    def unlock_user(
        self,
        user: str,
        password: str,
        create_new: bool,
        sync_approval: Literal['yes', 'no', 'unknown'],
        premium_credentials: Optional[PremiumCredentials],
        initial_settings: Optional[ModifiableDBSettings] = None,
    ) -> None:
        """Unlocks an existing user or creates a new one if `create_new` is True

        May raise:
        - PremiumAuthenticationError if the password can't unlock the database.
        - AuthenticationError if premium_credentials are given and are invalid
        or can't authenticate with the server
        - DBUpgradeError if the rotki DB version is newer than the software or
        there is a DB upgrade and there is an error.
        - SystemPermissionError if the directory or DB file can not be accessed
        """
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
            initial_settings=initial_settings,
        )

        # unlock or create the DB
        self.password = password
        self.user_directory = self.data.unlock(user, password, create_new,
                                               initial_settings)
        self.data_importer = DataImporter(db=self.data.db)
        self.last_data_upload_ts = self.data.db.get_last_data_upload_ts()
        self.premium_sync_manager = PremiumSyncManager(data=self.data,
                                                       password=password)
        # set the DB in the external services instances that need it
        self.cryptocompare.set_database(self.data.db)

        # Anything that was set above here has to be cleaned in case of failure in the next step
        # by reset_after_failed_account_creation_or_login()
        try:
            self.premium = self.premium_sync_manager.try_premium_at_start(
                given_premium_credentials=premium_credentials,
                username=user,
                create_new=create_new,
                sync_approval=sync_approval,
            )
        except PremiumAuthenticationError:
            # Reraise it only if this is during the creation of a new account where
            # the premium credentials were given by the user
            if create_new:
                raise
            self.msg_aggregator.add_error(
                'Tried to synchronize the database from remote but the local password '
                'does not match the one the remote DB has. Please change the password '
                'to be the same as the password of the account you want to sync from ',
            )
            # else let's just continue. User signed in successfully, but he just
            # has unauthenticable/invalid premium credentials remaining in his DB

        settings = self.get_settings()
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='submit_usage_analytics',
            method=maybe_submit_usage_analytics,
            should_submit=settings.submit_usage_analytics,
        )
        self.etherscan = Etherscan(database=self.data.db,
                                   msg_aggregator=self.msg_aggregator)
        historical_data_start = settings.historical_data_start
        eth_rpc_endpoint = settings.eth_rpc_endpoint
        # Initialize the price historian singleton
        PriceHistorian(
            data_directory=self.data_dir,
            history_date_start=historical_data_start,
            cryptocompare=self.cryptocompare,
        )
        self.accountant = Accountant(
            db=self.data.db,
            user_directory=self.user_directory,
            msg_aggregator=self.msg_aggregator,
            create_csv=True,
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=settings.anonymized_logs)
        exchange_credentials = self.data.db.get_exchange_credentials()
        self.exchange_manager.initialize_exchanges(
            exchange_credentials=exchange_credentials,
            database=self.data.db,
        )

        # Initialize blockchain querying modules
        ethereum_manager = EthereumManager(
            ethrpc_endpoint=eth_rpc_endpoint,
            etherscan=self.etherscan,
            database=self.data.db,
            msg_aggregator=self.msg_aggregator,
            greenlet_manager=self.greenlet_manager,
            connect_at_start=ETHEREUM_NODES_TO_CONNECT_AT_START,
        )
        Inquirer().inject_ethereum(ethereum_manager)
        self.chain_manager = ChainManager(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            ethereum_manager=ethereum_manager,
            msg_aggregator=self.msg_aggregator,
            database=self.data.db,
            greenlet_manager=self.greenlet_manager,
            premium=self.premium,
            eth_modules=settings.active_modules,
        )
        self.trades_historian = TradesHistorian(
            user_directory=self.user_directory,
            db=self.data.db,
            msg_aggregator=self.msg_aggregator,
            exchange_manager=self.exchange_manager,
            chain_manager=self.chain_manager,
        )
        self.user_is_logged_in = True
        log.debug('User unlocking complete')

    def logout(self) -> None:
        if not self.user_is_logged_in:
            return

        user = self.data.username
        log.info(
            'Logging out user',
            user=user,
        )
        self.greenlet_manager.clear()
        del self.chain_manager
        self.exchange_manager.delete_all_exchanges()

        # Reset rotkehlchen logger to default
        LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

        del self.accountant
        del self.trades_historian
        del self.data_importer

        if self.premium is not None:
            del self.premium
        self.data.logout()
        self.password = ''
        self.cryptocompare.unset_database()

        # Make sure no messages leak to other user sessions
        self.msg_aggregator.consume_errors()
        self.msg_aggregator.consume_warnings()

        self.user_is_logged_in = False
        log.info(
            'User successfully logged out',
            user=user,
        )

    def set_premium_credentials(self, credentials: PremiumCredentials) -> None:
        """
        Sets the premium credentials for Rotki

        Raises PremiumAuthenticationError if the given key is rejected by the Rotkehlchen server
        """
        log.info('Setting new premium credentials')
        if self.premium is not None:
            self.premium.set_credentials(credentials)
        else:
            self.premium = premium_create_and_verify(credentials)

        self.data.db.set_rotkehlchen_premium(credentials)

    def delete_premium_credentials(self) -> Tuple[bool, str]:
        """Deletes the premium credentials for Rotki"""
        msg = ''

        success = self.data.db.del_rotkehlchen_premium()
        if success is False:
            msg = 'The database was unable to delete the Premium keys for the logged-in user'
        self.deactivate_premium_status()
        return success, msg

    def deactivate_premium_status(self) -> None:
        """Deactivate premium in the current session"""
        self.premium = None
        self.premium_sync_manager.premium = None
        self.chain_manager.deactivate_premium_status()

    def start(self) -> gevent.Greenlet:
        return gevent.spawn(self.main_loop)

    def main_loop(self) -> None:
        """Rotki main loop that fires often and manages many different tasks

        Each task remembers the last time it ran successfully and knows how often it
        should run. So each task manages itself.
        """
        # super hacky -- organize better when recurring tasks are implemented
        # https://github.com/rotki/rotki/issues/1106
        xpub_derivation_scheduled = False
        while self.shutdown_event.wait(MAIN_LOOP_SECS_DELAY) is not True:
            if self.user_is_logged_in:
                log.debug('Main loop start')
                self.premium_sync_manager.maybe_upload_data_to_server()
                if not xpub_derivation_scheduled:
                    # 1 minute in the app's startup try to derive new xpub addresses
                    self.greenlet_manager.spawn_and_track(
                        after_seconds=60.0,
                        task_name='Derive new xpub addresses',
                        method=XpubManager(
                            self.chain_manager).check_for_new_xpub_addresses,
                    )
                    xpub_derivation_scheduled = True
                log.debug('Main loop end')

    def get_blockchain_account_data(
        self,
        blockchain: SupportedBlockchain,
    ) -> Union[List[BlockchainAccountData], Dict[str, Any]]:
        account_data = self.data.db.get_blockchain_account_data(blockchain)
        if blockchain != SupportedBlockchain.BITCOIN:
            return account_data

        xpub_data = self.data.db.get_bitcoin_xpub_data()
        addresses_to_account_data = {x.address: x for x in account_data}
        address_to_xpub_mappings = self.data.db.get_addresses_to_xpub_mapping(
            list(addresses_to_account_data.keys()),  # type: ignore
        )

        xpub_mappings: Dict['XpubData', List[BlockchainAccountData]] = {}
        for address, xpub_entry in address_to_xpub_mappings.items():
            if xpub_entry not in xpub_mappings:
                xpub_mappings[xpub_entry] = []
            xpub_mappings[xpub_entry].append(
                addresses_to_account_data[address])

        data: Dict[str, Any] = {'standalone': [], 'xpubs': []}
        # Add xpub data
        for xpub_entry in xpub_data:
            data_entry = xpub_entry.serialize()
            addresses = xpub_mappings.get(xpub_entry, None)
            data_entry['addresses'] = addresses if addresses and len(
                addresses) != 0 else None
            data['xpubs'].append(data_entry)
        # Add standalone addresses
        for account in account_data:
            if account.address not in address_to_xpub_mappings:
                data['standalone'].append(account)

        return data

    def add_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
    ) -> BlockchainBalancesUpdate:
        """Adds new blockchain accounts

        Adds the accounts to the blockchain instance and queries them to get the
        updated balances. Also adds them in the DB

        May raise:
        - EthSyncError from modify_blockchain_account
        - InputError if the given accounts list is empty.
        - TagConstraintError if any of the given account data contain unknown tags.
        - RemoteError if an external service such as Etherscan is queried and
          there is a problem with its query.
        """
        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='adding',
            data_type='blockchain accounts',
        )
        address_type = blockchain.get_address_type()
        updated_balances = self.chain_manager.add_blockchain_accounts(
            blockchain=blockchain,
            accounts=[address_type(entry.address) for entry in account_data],
        )
        self.data.db.add_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )

        return updated_balances

    def edit_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
    ) -> None:
        """Edits blockchain accounts

        Edits blockchain account data for the given accounts

        May raise:
        - InputError if the given accounts list is empty or if
        any of the accounts to edit do not exist.
        - TagConstraintError if any of the given account data contain unknown tags.
        """
        # First check for validity of account data addresses
        if len(account_data) == 0:
            raise InputError(
                'Empty list of blockchain account data to edit was given')
        accounts = [x.address for x in account_data]
        unknown_accounts = set(accounts).difference(
            self.chain_manager.accounts.get(blockchain))
        if len(unknown_accounts) != 0:
            raise InputError(
                f'Tried to edit unknown {blockchain.value} '
                f'accounts {",".join(unknown_accounts)}', )

        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='editing',
            data_type='blockchain accounts',
        )

        # Finally edit the accounts
        self.data.db.edit_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )


    def remove_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        accounts: ListOfBlockchainAddresses,
    ) -> BlockchainBalancesUpdate:
        """Removes blockchain accounts

        Removes the accounts from the blockchain instance and queries them to get
        the updated balances. Also removes them from the DB

        May raise:
        - RemoteError if an external service such as Etherscan is queried and
          there is a problem with its query.
        - InputError if a non-existing account was given to remove
        """
        balances_update = self.chain_manager.remove_blockchain_accounts(
            blockchain=blockchain,
            accounts=accounts,
        )
        self.data.db.remove_blockchain_accounts(blockchain, accounts)
        return balances_update

    def process_history(
        self,
        start_ts: Timestamp,
        end_ts: Timestamp,
    ) -> Tuple[Dict[str, Any], str]:
        (
            error_or_empty,
            history,
            loan_history,
            asset_movements,
            eth_transactions,
            defi_events,
        ) = self.trades_historian.get_history(
            start_ts=start_ts,
            end_ts=end_ts,
            has_premium=self.premium is not None,
        )
        result = self.accountant.process_history(
            start_ts=start_ts,
            end_ts=end_ts,
            trade_history=history,
            loan_history=loan_history,
            asset_movements=asset_movements,
            eth_transactions=eth_transactions,
            defi_events=defi_events,
        )
        return result, error_or_empty

    @overload
    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade'],
        location_actions: List[Trade],
        all_actions: List[Trade],
    ) -> List[Trade]:
        ...

    @overload
    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['asset_movement'],
        location_actions: List[AssetMovement],
        all_actions: List[AssetMovement],
    ) -> List[AssetMovement]:
        ...

    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade', 'asset_movement'],
        location_actions: Union[List[Trade], List[AssetMovement]],
        all_actions: Union[List[Trade], List[AssetMovement]],
    ) -> Union[List[Trade], List[AssetMovement]]:
        """Take as many actions from location actions and add them to all actions as the limit permits

        Returns the modified (or not) all_actions
        """
        # If we are already at or above the limit return current actions disregarding this location
        actions_mapping = self.actions_per_location[action_type]
        current_num_actions = sum(x for _, x in actions_mapping.items())
        limit = LIMITS_MAPPING[action_type]
        if current_num_actions >= limit:
            return all_actions

        # Find out how many more actions we can return and, depending on that, get
        # the number of actions from the location actions and add them to the total
        remaining_num_actions = limit - current_num_actions
        if remaining_num_actions < 0:
            remaining_num_actions = 0

        num_actions_to_take = min(len(location_actions), remaining_num_actions)

        actions_mapping[location] = num_actions_to_take
        all_actions.extend(
            location_actions[0:num_actions_to_take])  # type: ignore
        return all_actions
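        # Illustrative walk-through (not from the source): assume a hypothetical
        # limit of 5 for this action type and that 3 actions were already counted
        # for other locations. Then remaining_num_actions == 2, so at most 2 of
        # this location's actions are appended to all_actions and actions_mapping
        # records 2 for this location. A later call that finds the running total
        # already at or above the limit returns all_actions unchanged.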

    def query_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
    ) -> List[Trade]:
        """Queries trades for the given location and time range.
        If no location is given then all external and all exchange trades are queried.

        If the user does not have premium then a trade limit is applied.

        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        if location is not None:
            trades = self.query_location_trades(from_ts, to_ts, location)
        else:
            trades = self.query_location_trades(from_ts, to_ts,
                                                Location.EXTERNAL)
            for name, exchange in self.exchange_manager.connected_exchanges.items(
            ):
                exchange_trades = exchange.query_trade_history(
                    start_ts=from_ts, end_ts=to_ts)
                if self.premium is None:
                    trades = self._apply_actions_limit(
                        location=deserialize_location(name),
                        action_type='trade',
                        location_actions=exchange_trades,
                        all_actions=trades,
                    )
                else:
                    trades.extend(exchange_trades)

        # return trades with most recent first
        trades.sort(key=lambda x: x.timestamp, reverse=True)
        return trades

    def query_location_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Location,
    ) -> List[Trade]:
        # clear the trades queried for this location
        self.actions_per_location['trade'][location] = 0

        if location == Location.EXTERNAL:
            location_trades = self.data.db.get_trades(
                from_ts=from_ts,
                to_ts=to_ts,
                location=location,
            )
        else:
            # should only be an exchange
            exchange = self.exchange_manager.get(str(location))
            if not exchange:
                logger.warning(
                    f'Tried to query trades from {location} which is either not an '
                    f'exchange or not an exchange the user has connected to', )
                return []

            location_trades = exchange.query_trade_history(start_ts=from_ts,
                                                           end_ts=to_ts)

        trades: List[Trade] = []
        if self.premium is None:
            trades = self._apply_actions_limit(
                location=location,
                action_type='trade',
                location_actions=location_trades,
                all_actions=trades,
            )
        else:
            trades = location_trades

        return trades

    def query_balances(
        self,
        requested_save_data: bool = False,
        timestamp: Optional[Timestamp] = None,
        ignore_cache: bool = False,
    ) -> Dict[str, Any]:
        """Query all balances rotkehlchen can see.

        If requested_save_data is True then the data are always saved in the DB;
        if it is False then data are saved only if self.data.should_save_balances()
        is True.
        If timestamp is None then the current timestamp is used.
        If a timestamp is given then that is the time that the balances are going
        to be saved in the DB
        If ignore_cache is True then all underlying calls that have a cache ignore it

        Returns a dictionary with the queried balances.
        """
        log.info('query_balances called',
                 requested_save_data=requested_save_data)

        balances = {}
        problem_free = True
        for _, exchange in self.exchange_manager.connected_exchanges.items():
            exchange_balances, _ = exchange.query_balances(
                ignore_cache=ignore_cache)
            # If we got an error, disregard that exchange but make sure we don't save data
            if not isinstance(exchange_balances, dict):
                problem_free = False
            else:
                balances[exchange.name] = exchange_balances

        try:
            blockchain_result = self.chain_manager.query_balances(
                blockchain=None,
                force_token_detection=ignore_cache,
                ignore_cache=ignore_cache,
            )
            balances['blockchain'] = {
                asset: balance.to_dict()
                for asset, balance in blockchain_result.totals.items()
            }
        except (RemoteError, EthSyncError) as e:
            problem_free = False
            log.error(f'Querying blockchain balances failed due to: {str(e)}')

        balances = account_for_manually_tracked_balances(db=self.data.db,
                                                         balances=balances)

        combined = combine_stat_dicts([v for k, v in balances.items()])
        total_usd_per_location = [(k, dict_get_sumof(v, 'usd_value'))
                                  for k, v in balances.items()]

        # calculate net usd value
        net_usd = FVal(0)
        for _, v in combined.items():
            net_usd += FVal(v['usd_value'])

        stats: Dict[str, Any] = {
            'location': {},
            'net_usd': net_usd,
        }
        for entry in total_usd_per_location:
            name = entry[0]
            total = entry[1]
            if net_usd != FVal(0):
                percentage = (total / net_usd).to_percentage()
            else:
                percentage = '0%'
            stats['location'][name] = {
                'usd_value': total,
                'percentage_of_net_value': percentage,
            }

        for k, v in combined.items():
            if net_usd != FVal(0):
                percentage = (v['usd_value'] / net_usd).to_percentage()
            else:
                percentage = '0%'
            combined[k]['percentage_of_net_value'] = percentage

        result_dict = merge_dicts(combined, stats)

        allowed_to_save = requested_save_data or self.data.should_save_balances(
        )

        if problem_free and allowed_to_save:
            if not timestamp:
                timestamp = Timestamp(int(time.time()))
            self.data.save_balances_data(data=result_dict, timestamp=timestamp)
            log.debug('query_balances data saved')
        else:
            log.debug(
                'query_balances data not saved',
                allowed_to_save=allowed_to_save,
                problem_free=problem_free,
            )

        # After adding it to the saved file we can overlay additional data that
        # is not required to be saved in the history file
        try:
            details = self.accountant.events.details
            for asset, (tax_free_amount, average_buy_value) in details.items():
                if asset not in result_dict:
                    continue

                result_dict[asset]['tax_free_amount'] = tax_free_amount
                result_dict[asset]['average_buy_value'] = average_buy_value

                current_price = result_dict[asset]['usd_value'] / result_dict[
                    asset]['amount']
                if average_buy_value != FVal(0):
                    result_dict[asset]['percent_change'] = (
                        ((current_price - average_buy_value) /
                         average_buy_value) * 100)
                else:
                    result_dict[asset]['percent_change'] = 'INF'

        except AttributeError:
            pass

        return result_dict
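        # Worked example (not from the source): if an asset's current price works
        # out to 120 USD and its average_buy_value is 100 USD, percent_change
        # becomes ((120 - 100) / 100) * 100 = 20, i.e. +20%. A zero average buy
        # value yields the sentinel string 'INF' instead.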

    def _query_exchange_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        all_movements: List[AssetMovement],
        exchange: ExchangeInterface,
    ) -> List[AssetMovement]:
        location = deserialize_location(exchange.name)
        # clear the asset movements queried for this exchange
        self.actions_per_location['asset_movement'][location] = 0
        location_movements = exchange.query_deposits_withdrawals(
            start_ts=from_ts, end_ts=to_ts)

        movements: List[AssetMovement] = []
        if self.premium is None:
            movements = self._apply_actions_limit(
                location=location,
                action_type='asset_movement',
                location_actions=location_movements,
                all_actions=all_movements,
            )
        else:
            movements = location_movements

        return movements

    def query_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
    ) -> List[AssetMovement]:
        """Queries AssetMovements for the given location and time range.

        If no location is given then all exchange asset movements are queried.
        If the user does not have premium then a limit is applied.
        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        movements: List[AssetMovement] = []
        if location is not None:
            exchange = self.exchange_manager.get(str(location))
            if not exchange:
                logger.warning(
                    f'Tried to query deposits/withdrawals from {location} which is either not an '
                    f'exchange or not an exchange the user has connected to', )
                return []
            movements = self._query_exchange_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                all_movements=movements,
                exchange=exchange,
            )
        else:
            for _, exchange in self.exchange_manager.connected_exchanges.items(
            ):
                movements = self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=exchange,
                )

        # return movements with most recent first
        movements.sort(key=lambda x: x.timestamp, reverse=True)
        return movements

    def set_settings(self, settings: ModifiableDBSettings) -> Tuple[bool, str]:
        """Tries to set new settings. Returns True in success or False with message if error"""
        with self.lock:
            if settings.eth_rpc_endpoint is not None:
                result, msg = self.chain_manager.set_eth_rpc_endpoint(
                    settings.eth_rpc_endpoint)
                if not result:
                    return False, msg

            if settings.kraken_account_type is not None:
                kraken = self.exchange_manager.get('kraken')
                if kraken:
                    kraken.set_account_type(
                        settings.kraken_account_type)  # type: ignore

            self.data.db.set_settings(settings)
            return True, ''

    def get_settings(self) -> DBSettings:
        """Returns the db settings with a check whether premium is active or not"""
        db_settings = self.data.db.get_settings(
            have_premium=self.premium is not None)
        return db_settings

    def setup_exchange(
        self,
        name: str,
        api_key: ApiKey,
        api_secret: ApiSecret,
        passphrase: Optional[str] = None,
    ) -> Tuple[bool, str]:
        """
        Set up a new exchange with an API key, an API secret and optionally a passphrase
        """
        is_success, msg = self.exchange_manager.setup_exchange(
            name=name,
            api_key=api_key,
            api_secret=api_secret,
            database=self.data.db,
            passphrase=passphrase,
        )

        if is_success:
            # Success, save the result in the DB
            self.data.db.add_exchange(name,
                                      api_key,
                                      api_secret,
                                      passphrase=passphrase)
        return is_success, msg

    def remove_exchange(self, name: str) -> Tuple[bool, str]:
        if not self.exchange_manager.has_exchange(name):
            return False, 'Exchange {} is not registered'.format(name)

        self.exchange_manager.delete_exchange(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        self.data.db.delete_used_query_range_for_exchange(name)
        return True, ''

    def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
        """Query for frequently changing data"""
        result: Dict[str, Union[bool, Timestamp]] = {}

        if self.user_is_logged_in:
            result[
                'last_balance_save'] = self.data.db.get_last_balance_save_time(
                )
            result[
                'eth_node_connection'] = self.chain_manager.ethereum.web3_mapping.get(
                    NodeName.OWN, None) is not None  # noqa: E501
            result[
                'history_process_start_ts'] = self.accountant.started_processing_timestamp
            result[
                'history_process_current_ts'] = self.accountant.currently_processing_timestamp
            result['last_data_upload_ts'] = Timestamp(
                self.premium_sync_manager.last_data_upload_ts)  # noqa: E501
        return result

    def shutdown(self) -> None:
        self.logout()
        self.shutdown_event.set()
Code example #36
0
File: rotkehlchen.py Project: rkalis/rotki
class Rotkehlchen():
    def __init__(self, args: argparse.Namespace) -> None:
        """Initialize the Rotkehlchen object

        May Raise:
        - SystemPermissionError if the given data directory's permissions
        are not correct.
        """
        self.lock = Semaphore()
        self.lock.acquire()
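        # Note: self.lock is a plain gevent Semaphore (default counter of 1) used
        # as a mutex. It is held for the whole of __init__ and released at the
        # end, and later re-acquired (e.g. in set_settings) presumably to
        # serialize state changes against concurrently running greenlets.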

        # Can also be None after unlock if premium credentials did not
        # authenticate or premium server temporarily offline
        self.premium: Optional[Premium] = None
        self.user_is_logged_in: bool = False
        configure_logging(args)

        self.sleep_secs = args.sleep_secs
        if args.data_dir is None:
            self.data_dir = default_data_directory()
        else:
            self.data_dir = Path(args.data_dir)

        if not os.access(self.data_dir, os.W_OK | os.R_OK):
            raise SystemPermissionError(
                f'The given data directory {self.data_dir} is not readable or writable',
            )
        self.main_loop_spawned = False
        self.args = args
        self.api_task_greenlets: List[gevent.Greenlet] = []
        self.msg_aggregator = MessagesAggregator()
        self.greenlet_manager = GreenletManager(
            msg_aggregator=self.msg_aggregator)
        self.exchange_manager = ExchangeManager(
            msg_aggregator=self.msg_aggregator)
        # Initialize the AssetResolver singleton
        AssetResolver(data_directory=self.data_dir)
        self.data = DataHandler(self.data_dir, self.msg_aggregator)
        self.cryptocompare = Cryptocompare(data_directory=self.data_dir,
                                           database=None)
        self.coingecko = Coingecko(data_directory=self.data_dir)
        self.icon_manager = IconManager(data_dir=self.data_dir,
                                        coingecko=self.coingecko)
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='periodically_query_icons_until_all_cached',
            exception_is_error=False,
            method=self.icon_manager.periodically_query_icons_until_all_cached,
            batch_size=ICONS_BATCH_SIZE,
            sleep_time_secs=ICONS_QUERY_SLEEP,
        )
        # Initialize the Inquirer singleton
        Inquirer(
            data_dir=self.data_dir,
            cryptocompare=self.cryptocompare,
            coingecko=self.coingecko,
        )
        # Keeps how many trades we have found per location. Used for free user limiting
        self.actions_per_location: Dict[str, Dict[Location, int]] = {
            'trade': defaultdict(int),
            'asset_movement': defaultdict(int),
        }

        self.lock.release()
        self.task_manager: Optional[TaskManager] = None
        self.shutdown_event = gevent.event.Event()

    def reset_after_failed_account_creation_or_login(self) -> None:
        """If the account creation or login failed make sure that the Rotki instance is clear

        Tricky instances are when after either failed premium credentials or user refusal
        to sync premium databases we relogged in.
        """
        self.cryptocompare.db = None

    def unlock_user(
        self,
        user: str,
        password: str,
        create_new: bool,
        sync_approval: Literal['yes', 'no', 'unknown'],
        premium_credentials: Optional[PremiumCredentials],
        initial_settings: Optional[ModifiableDBSettings] = None,
    ) -> None:
        """Unlocks an existing user or creates a new one if `create_new` is True

        May raise:
        - AuthenticationError if the password can't unlock the database.
        - PremiumAuthenticationError if premium_credentials are given and are invalid
        or can't authenticate with the server
        - DBUpgradeError if the rotki DB version is newer than the software or
        there is a DB upgrade and there is an error.
        - SystemPermissionError if the directory or DB file can not be accessed
        """
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
            initial_settings=initial_settings,
        )

        # unlock or create the DB
        self.password = password
        self.user_directory = self.data.unlock(user, password, create_new,
                                               initial_settings)
        self.data_importer = DataImporter(db=self.data.db)
        self.last_data_upload_ts = self.data.db.get_last_data_upload_ts()
        self.premium_sync_manager = PremiumSyncManager(data=self.data,
                                                       password=password)
        # set the DB in the external services instances that need it
        self.cryptocompare.set_database(self.data.db)

        # Anything that was set above here has to be cleaned in case of failure in the next step
        # by reset_after_failed_account_creation_or_login()
        try:
            self.premium = self.premium_sync_manager.try_premium_at_start(
                given_premium_credentials=premium_credentials,
                username=user,
                create_new=create_new,
                sync_approval=sync_approval,
            )
        except PremiumAuthenticationError:
            # Reraise it only if this is during the creation of a new account where
            # the premium credentials were given by the user
            if create_new:
                raise
            self.msg_aggregator.add_warning(
                'Could not authenticate the Rotki premium API keys found in the DB.'
                ' Has your subscription expired?', )
            # else just continue. The user signed in successfully, but they just
            # have unauthenticatable/invalid premium credentials remaining in their DB

        settings = self.get_settings()
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='submit_usage_analytics',
            exception_is_error=False,
            method=maybe_submit_usage_analytics,
            should_submit=settings.submit_usage_analytics,
        )
        self.etherscan = Etherscan(database=self.data.db,
                                   msg_aggregator=self.msg_aggregator)
        self.beaconchain = BeaconChain(database=self.data.db,
                                       msg_aggregator=self.msg_aggregator)
        eth_rpc_endpoint = settings.eth_rpc_endpoint
        # Initialize the price historian singleton
        PriceHistorian(
            data_directory=self.data_dir,
            cryptocompare=self.cryptocompare,
            coingecko=self.coingecko,
        )
        PriceHistorian().set_oracles_order(settings.historical_price_oracles)

        self.accountant = Accountant(
            db=self.data.db,
            user_directory=self.user_directory,
            msg_aggregator=self.msg_aggregator,
            create_csv=True,
            premium=self.premium,
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=settings.anonymized_logs)
        exchange_credentials = self.data.db.get_exchange_credentials()
        self.exchange_manager.initialize_exchanges(
            exchange_credentials=exchange_credentials,
            database=self.data.db,
        )

        # Initialize blockchain querying modules
        ethereum_manager = EthereumManager(
            ethrpc_endpoint=eth_rpc_endpoint,
            etherscan=self.etherscan,
            database=self.data.db,
            msg_aggregator=self.msg_aggregator,
            greenlet_manager=self.greenlet_manager,
            connect_at_start=ETHEREUM_NODES_TO_CONNECT_AT_START,
        )
        kusama_manager = SubstrateManager(
            chain=SubstrateChain.KUSAMA,
            msg_aggregator=self.msg_aggregator,
            greenlet_manager=self.greenlet_manager,
            connect_at_start=KUSAMA_NODES_TO_CONNECT_AT_START,
            connect_on_startup=self._connect_ksm_manager_on_startup(),
            own_rpc_endpoint=settings.ksm_rpc_endpoint,
        )

        Inquirer().inject_ethereum(ethereum_manager)
        Inquirer().set_oracles_order(settings.current_price_oracles)

        self.chain_manager = ChainManager(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            ethereum_manager=ethereum_manager,
            kusama_manager=kusama_manager,
            msg_aggregator=self.msg_aggregator,
            database=self.data.db,
            greenlet_manager=self.greenlet_manager,
            premium=self.premium,
            eth_modules=settings.active_modules,
            data_directory=self.data_dir,
            beaconchain=self.beaconchain,
            btc_derivation_gap_limit=settings.btc_derivation_gap_limit,
        )
        self.events_historian = EventsHistorian(
            user_directory=self.user_directory,
            db=self.data.db,
            msg_aggregator=self.msg_aggregator,
            exchange_manager=self.exchange_manager,
            chain_manager=self.chain_manager,
        )
        self.task_manager = TaskManager(
            max_tasks_num=DEFAULT_MAX_TASKS_NUM,
            greenlet_manager=self.greenlet_manager,
            api_task_greenlets=self.api_task_greenlets,
            database=self.data.db,
            cryptocompare=self.cryptocompare,
            premium_sync_manager=self.premium_sync_manager,
            chain_manager=self.chain_manager,
            exchange_manager=self.exchange_manager,
        )
        self.user_is_logged_in = True
        log.debug('User unlocking complete')

    def logout(self) -> None:
        if not self.user_is_logged_in:
            return
        user = self.data.username
        log.info(
            'Logging out user',
            user=user,
        )

        self.deactivate_premium_status()
        self.greenlet_manager.clear()
        del self.chain_manager
        self.exchange_manager.delete_all_exchanges()

        # Reset rotkehlchen logger to default
        LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

        del self.accountant
        del self.events_historian
        del self.data_importer

        self.data.logout()
        self.password = ''
        self.cryptocompare.unset_database()

        # Make sure no messages leak to other user sessions
        self.msg_aggregator.consume_errors()
        self.msg_aggregator.consume_warnings()
        self.task_manager = None

        self.user_is_logged_in = False
        log.info(
            'User successfully logged out',
            user=user,
        )

    def set_premium_credentials(self, credentials: PremiumCredentials) -> None:
        """
        Sets the premium credentials for Rotki

        Raises PremiumAuthenticationError if the given key is rejected by the Rotkehlchen server
        """
        log.info('Setting new premium credentials')
        if self.premium is not None:
            self.premium.set_credentials(credentials)
        else:
            self.premium = premium_create_and_verify(credentials)
            self.premium_sync_manager.premium = self.premium
            self.accountant.premium = self.premium

        self.data.db.set_rotkehlchen_premium(credentials)

    def delete_premium_credentials(self) -> Tuple[bool, str]:
        """Deletes the premium credentials for Rotki"""
        msg = ''

        success = self.data.db.del_rotkehlchen_premium()
        if success is False:
            msg = 'The database was unable to delete the Premium keys for the logged-in user'
        self.deactivate_premium_status()
        return success, msg

    def deactivate_premium_status(self) -> None:
        """Deactivate premium in the current session"""
        self.premium = None
        self.premium_sync_manager.premium = None
        self.chain_manager.deactivate_premium_status()
        self.accountant.deactivate_premium_status()

    def start(self) -> gevent.Greenlet:
        assert not self.main_loop_spawned, 'Tried to spawn the main loop twice'
        greenlet = gevent.spawn(self.main_loop)
        self.main_loop_spawned = True
        return greenlet

    def main_loop(self) -> None:
        """Rotki main loop that fires often and runs the task manager's scheduler"""
        while self.shutdown_event.wait(
                timeout=MAIN_LOOP_SECS_DELAY) is not True:
            if self.task_manager is not None:
                self.task_manager.schedule()
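        # Illustrative sketch (not from the source) of the shutdown pattern used
        # above: Event.wait(timeout) sleeps cooperatively and only returns True
        # once the event is set, so the loop doubles as a periodic scheduler and
        # a shutdown hook. A minimal standalone version, assuming gevent is
        # installed, could look like:
        #
        #   import gevent
        #   from gevent.event import Event
        #
        #   stop = Event()
        #
        #   def loop():
        #       while stop.wait(timeout=1.0) is not True:
        #           print('tick')  # periodic work goes here
        #
        #   g = gevent.spawn(loop)
        #   gevent.sleep(3)
        #   stop.set()  # wait() now returns True and the loop exits
        #   g.join()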

    def get_blockchain_account_data(
        self,
        blockchain: SupportedBlockchain,
    ) -> Union[List[BlockchainAccountData], Dict[str, Any]]:
        account_data = self.data.db.get_blockchain_account_data(blockchain)
        if blockchain != SupportedBlockchain.BITCOIN:
            return account_data

        xpub_data = self.data.db.get_bitcoin_xpub_data()
        addresses_to_account_data = {x.address: x for x in account_data}
        address_to_xpub_mappings = self.data.db.get_addresses_to_xpub_mapping(
            list(addresses_to_account_data.keys()),  # type: ignore
        )

        xpub_mappings: Dict['XpubData', List[BlockchainAccountData]] = {}
        for address, xpub_entry in address_to_xpub_mappings.items():
            if xpub_entry not in xpub_mappings:
                xpub_mappings[xpub_entry] = []
            xpub_mappings[xpub_entry].append(
                addresses_to_account_data[address])

        data: Dict[str, Any] = {'standalone': [], 'xpubs': []}
        # Add xpub data
        for xpub_entry in xpub_data:
            data_entry = xpub_entry.serialize()
            addresses = xpub_mappings.get(xpub_entry, None)
            data_entry['addresses'] = addresses if addresses and len(
                addresses) != 0 else None
            data['xpubs'].append(data_entry)
        # Add standalone addresses
        for account in account_data:
            if account.address not in address_to_xpub_mappings:
                data['standalone'].append(account)

        return data

    def add_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
    ) -> BlockchainBalancesUpdate:
        """Adds new blockchain accounts

        Adds the accounts to the blockchain instance and queries them to get the
        updated balances. Also adds them to the DB

        May raise:
        - EthSyncError from modify_blockchain_account
        - InputError if the given accounts list is empty.
        - TagConstraintError if any of the given account data contain unknown tags.
        - RemoteError if an external service such as Etherscan is queried and
          there is a problem with its query.
        """
        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='adding',
            data_type='blockchain accounts',
        )
        address_type = blockchain.get_address_type()
        updated_balances = self.chain_manager.add_blockchain_accounts(
            blockchain=blockchain,
            accounts=[address_type(entry.address) for entry in account_data],
        )
        self.data.db.add_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )

        return updated_balances

    def edit_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
    ) -> None:
        """Edits blockchain accounts

        Edits blockchain account data for the given accounts

        May raise:
        - InputError if the given accounts list is empty or if
        any of the accounts to edit do not exist.
        - TagConstraintError if any of the given account data contain unknown tags.
        """
        # First check for validity of account data addresses
        if len(account_data) == 0:
            raise InputError(
                'Empty list of blockchain account data to edit was given')
        accounts = [x.address for x in account_data]
        unknown_accounts = set(accounts).difference(
            self.chain_manager.accounts.get(blockchain))
        if len(unknown_accounts) != 0:
            raise InputError(
                f'Tried to edit unknown {blockchain.value} '
                f'accounts {",".join(unknown_accounts)}', )

        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='editing',
            data_type='blockchain accounts',
        )

        # Finally edit the accounts
        self.data.db.edit_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )

    def remove_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        accounts: ListOfBlockchainAddresses,
    ) -> BlockchainBalancesUpdate:
        """Removes blockchain accounts

        Removes the accounts from the blockchain instance and queries them to get
        the updated balances. Also removes them from the DB

        May raise:
        - RemoteError if an external service such as Etherscan is queried and
          there is a problem with its query.
        - InputError if a non-existing account was given to remove
        """
        balances_update = self.chain_manager.remove_blockchain_accounts(
            blockchain=blockchain,
            accounts=accounts,
        )
        self.data.db.remove_blockchain_accounts(blockchain, accounts)
        return balances_update

    def get_history_query_status(self) -> Dict[str, str]:
        if self.events_historian.progress < FVal('100'):
            processing_state = self.events_historian.processing_state_name
            progress = self.events_historian.progress / 2
        elif self.accountant.first_processed_timestamp == -1:
            processing_state = 'Processing all retrieved historical events'
            progress = FVal(50)
        else:
            processing_state = 'Processing all retrieved historical events'
            # start_ts is min of the query start or the first action timestamp since action
            # processing can start well before query start to calculate cost basis
            start_ts = min(
                self.accountant.events.query_start_ts,
                self.accountant.first_processed_timestamp,
            )
            diff = self.accountant.events.query_end_ts - start_ts
            progress = 50 + 100 * (
                FVal(self.accountant.currently_processing_timestamp - start_ts)
                / FVal(diff) / 2)

        return {
            'processing_state': str(processing_state),
            'total_progress': str(progress)
        }
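        # Worked example (not from the source): history retrieval accounts for
        # the first 50% of progress and event processing for the second 50%. If
        # the processed range spans 1000 seconds and the accountant is currently
        # 600 seconds past start_ts, the last branch yields
        # 50 + 100 * (600 / 1000 / 2) = 80, i.e. a total progress of 80%.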

    def process_history(
        self,
        start_ts: Timestamp,
        end_ts: Timestamp,
    ) -> Tuple[Dict[str, Any], str]:
        (
            error_or_empty,
            history,
            loan_history,
            asset_movements,
            eth_transactions,
            defi_events,
            ledger_actions,
        ) = self.events_historian.get_history(
            start_ts=start_ts,
            end_ts=end_ts,
            has_premium=self.premium is not None,
        )
        result = self.accountant.process_history(
            start_ts=start_ts,
            end_ts=end_ts,
            trade_history=history,
            loan_history=loan_history,
            asset_movements=asset_movements,
            eth_transactions=eth_transactions,
            defi_events=defi_events,
            ledger_actions=ledger_actions,
        )
        return result, error_or_empty

    @overload
    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade'],
        location_actions: TRADES_LIST,
        all_actions: TRADES_LIST,
    ) -> TRADES_LIST:
        ...

    @overload
    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['asset_movement'],
        location_actions: List[AssetMovement],
        all_actions: List[AssetMovement],
    ) -> List[AssetMovement]:
        ...

    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade', 'asset_movement'],
        location_actions: Union[TRADES_LIST, List[AssetMovement]],
        all_actions: Union[TRADES_LIST, List[AssetMovement]],
    ) -> Union[TRADES_LIST, List[AssetMovement]]:
        """Take as many actions from location actions and add them to all actions as the limit permits

        Returns the modified (or not) all_actions
        """
        # If we are already at or above the limit return current actions disregarding this location
        actions_mapping = self.actions_per_location[action_type]
        current_num_actions = sum(x for _, x in actions_mapping.items())
        limit = LIMITS_MAPPING[action_type]
        if current_num_actions >= limit:
            return all_actions

        # Find out how many more actions we can return and, depending on that, get
        # the number of actions from the location actions and add them to the total
        remaining_num_actions = limit - current_num_actions
        if remaining_num_actions < 0:
            remaining_num_actions = 0

        num_actions_to_take = min(len(location_actions), remaining_num_actions)

        actions_mapping[location] = num_actions_to_take
        all_actions.extend(
            location_actions[0:num_actions_to_take])  # type: ignore
        return all_actions

    def query_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
    ) -> TRADES_LIST:
        """Queries trades for the given location and time range.
        If no location is given then all external, all exchange and DEX trades are queried.

        DEX trades are queried only if the user has premium.
        If the user does not have premium then a trade limit is applied.

        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        trades: TRADES_LIST
        if location is not None:
            trades = self.query_location_trades(from_ts, to_ts, location)
        else:
            trades = self.query_location_trades(from_ts, to_ts,
                                                Location.EXTERNAL)
            # crypto.com is not an API key supported exchange but user can import from CSV
            trades.extend(
                self.query_location_trades(from_ts, to_ts, Location.CRYPTOCOM))
            for name, exchange in self.exchange_manager.connected_exchanges.items(
            ):
                exchange_trades = exchange.query_trade_history(
                    start_ts=from_ts, end_ts=to_ts)
                if self.premium is None:
                    trades = self._apply_actions_limit(
                        location=deserialize_location(name),
                        action_type='trade',
                        location_actions=exchange_trades,
                        all_actions=trades,
                    )
                else:
                    trades.extend(exchange_trades)

            # for all trades we also need uniswap trades
            if self.premium is not None:
                uniswap = self.chain_manager.uniswap
                if uniswap is not None:
                    trades.extend(
                        uniswap.get_trades(
                            addresses=self.chain_manager.
                            queried_addresses_for_module('uniswap'),
                            from_timestamp=from_ts,
                            to_timestamp=to_ts,
                        ), )

        # return trades with most recent first
        trades.sort(key=lambda x: x.timestamp, reverse=True)
        return trades

    def query_location_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Location,
    ) -> TRADES_LIST:
        # clear the trades queried for this location
        self.actions_per_location['trade'][location] = 0

        location_trades: TRADES_LIST
        if location in (Location.EXTERNAL, Location.CRYPTOCOM):
            location_trades = self.data.db.get_trades(  # type: ignore  # list invariance
                from_ts=from_ts,
                to_ts=to_ts,
                location=location,
            )
        elif location == Location.UNISWAP:
            if self.premium is not None:
                uniswap = self.chain_manager.uniswap
                if uniswap is not None:
                    location_trades = uniswap.get_trades(  # type: ignore  # list invariance
                        addresses=self.chain_manager.
                        queried_addresses_for_module('uniswap'),
                        from_timestamp=from_ts,
                        to_timestamp=to_ts,
                    )
        else:
            # should only be an exchange
            exchange = self.exchange_manager.get(str(location))
            if not exchange:
                logger.warning(
                    f'Tried to query trades from {location} which is either not an '
                    f'exchange or not an exchange the user has connected to', )
                return []

            location_trades = exchange.query_trade_history(start_ts=from_ts,
                                                           end_ts=to_ts)

        trades: TRADES_LIST = []
        if self.premium is None:
            trades = self._apply_actions_limit(
                location=location,
                action_type='trade',
                location_actions=location_trades,
                all_actions=trades,
            )
        else:
            trades = location_trades

        return trades

    def query_balances(
        self,
        requested_save_data: bool = False,
        timestamp: Optional[Timestamp] = None,
        ignore_cache: bool = False,
    ) -> Dict[str, Any]:
        """Query all balances rotkehlchen can see.

        If requested_save_data is True then the data are always saved in the DB;
        if it is False then data are saved only if self.data.should_save_balances()
        is True.
        If timestamp is None then the current timestamp is used.
        If a timestamp is given then that is the time that the balances are going
        to be saved in the DB
        If ignore_cache is True then all underlying calls that have a cache ignore it

        Returns a dictionary with the queried balances.
        """
        log.info('query_balances called',
                 requested_save_data=requested_save_data)

        balances: Dict[str, Dict[Asset, Balance]] = {}
        problem_free = True
        for _, exchange in self.exchange_manager.connected_exchanges.items():
            exchange_balances, _ = exchange.query_balances(
                ignore_cache=ignore_cache)
            # If we got an error, disregard that exchange but make sure we don't save data
            if not isinstance(exchange_balances, dict):
                problem_free = False
            else:
                balances[exchange.name] = exchange_balances

        liabilities: Dict[Asset, Balance]
        try:
            blockchain_result = self.chain_manager.query_balances(
                blockchain=None,
                force_token_detection=ignore_cache,
                ignore_cache=ignore_cache,
            )
            balances[str(
                Location.BLOCKCHAIN)] = blockchain_result.totals.assets
            liabilities = blockchain_result.totals.liabilities
        except (RemoteError, EthSyncError) as e:
            problem_free = False
            liabilities = {}
            log.error(f'Querying blockchain balances failed due to: {str(e)}')

        balances = account_for_manually_tracked_balances(db=self.data.db,
                                                         balances=balances)

        # Calculate usd totals
        assets_total_balance: DefaultDict[Asset,
                                          Balance] = defaultdict(Balance)
        total_usd_per_location: Dict[str, FVal] = {}
        for location, asset_balance in balances.items():
            total_usd_per_location[location] = ZERO
            for asset, balance in asset_balance.items():
                assets_total_balance[asset] += balance
                total_usd_per_location[location] += balance.usd_value

        net_usd = sum((balance.usd_value
                       for _, balance in assets_total_balance.items()), ZERO)
        liabilities_total_usd = sum(
            (liability.usd_value for _, liability in liabilities.items()),
            ZERO)  # noqa: E501
        net_usd -= liabilities_total_usd

        # Calculate location stats
        location_stats: Dict[str, Any] = {}
        for location, total_usd in total_usd_per_location.items():
            if location == str(Location.BLOCKCHAIN):
                total_usd -= liabilities_total_usd

            percentage = (total_usd /
                          net_usd).to_percentage() if net_usd != ZERO else '0%'
            location_stats[location] = {
                'usd_value': total_usd,
                'percentage_of_net_value': percentage,
            }

        # Calculate 'percentage_of_net_value' per asset
        assets_total_balance_as_dict: Dict[Asset, Dict[str, Any]] = {
            asset: balance.to_dict()
            for asset, balance in assets_total_balance.items()
        }
        liabilities_as_dict: Dict[Asset, Dict[str, Any]] = {
            asset: balance.to_dict()
            for asset, balance in liabilities.items()
        }
        for asset, balance_dict in assets_total_balance_as_dict.items():
            percentage = (balance_dict['usd_value'] / net_usd).to_percentage(
            ) if net_usd != ZERO else '0%'  # noqa: E501
            assets_total_balance_as_dict[asset][
                'percentage_of_net_value'] = percentage

        for asset, balance_dict in liabilities_as_dict.items():
            percentage = (balance_dict['usd_value'] / net_usd).to_percentage(
            ) if net_usd != ZERO else '0%'  # noqa: E501
            liabilities_as_dict[asset]['percentage_of_net_value'] = percentage

        # Compose balances response
        result_dict = {
            'assets': assets_total_balance_as_dict,
            'liabilities': liabilities_as_dict,
            'location': location_stats,
            'net_usd': net_usd,
        }
        allowed_to_save = requested_save_data or self.data.should_save_balances(
        )

        if problem_free and allowed_to_save:
            if not timestamp:
                timestamp = Timestamp(int(time.time()))
            self.data.db.save_balances_data(data=result_dict,
                                            timestamp=timestamp)
            log.debug('query_balances data saved')
        else:
            log.debug(
                'query_balances data not saved',
                allowed_to_save=allowed_to_save,
                problem_free=problem_free,
            )

        return result_dict
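        # Illustrative shape of the response built above (values elided; Balance's
        # to_dict() is assumed to expose 'amount' and 'usd_value'):
        #
        #   {
        #       'assets': {Asset: {'amount': ..., 'usd_value': ..., 'percentage_of_net_value': '12.3%'}},
        #       'liabilities': {Asset: {...}},
        #       'location': {'blockchain': {'usd_value': ..., 'percentage_of_net_value': '45.6%'}, ...},
        #       'net_usd': FVal(...),
        #   }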

    def _query_exchange_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        all_movements: List[AssetMovement],
        exchange: Union[ExchangeInterface, Location],
    ) -> List[AssetMovement]:
        if isinstance(exchange, ExchangeInterface):
            location = deserialize_location(exchange.name)
            # clear the asset movements queried for this exchange
            self.actions_per_location['asset_movement'][location] = 0
            location_movements = exchange.query_deposits_withdrawals(
                start_ts=from_ts,
                end_ts=to_ts,
            )
        else:
            assert isinstance(exchange,
                              Location), 'only a location should make it here'
            assert exchange == Location.CRYPTOCOM, 'only cryptocom should make it here'
            location = exchange
            # cryptocom has no exchange integration but we may have DB entries
            self.actions_per_location['asset_movement'][location] = 0
            location_movements = self.data.db.get_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                location=location,
            )

        movements: List[AssetMovement] = []
        if self.premium is None:
            movements = self._apply_actions_limit(
                location=location,
                action_type='asset_movement',
                location_actions=location_movements,
                all_actions=all_movements,
            )
        else:
            all_movements.extend(location_movements)
            movements = all_movements

        return movements

    def query_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
    ) -> List[AssetMovement]:
        """Queries AssetMovements for the given location and time range.

        If no location is given then all exchange asset movements are queried.
        If the user does not have premium then a limit is applied.
        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        movements: List[AssetMovement] = []
        if location is not None:
            if location == Location.CRYPTOCOM:
                movements = self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=Location.CRYPTOCOM,
                )
            else:
                exchange = self.exchange_manager.get(str(location))
                if not exchange:
                    logger.warning(
                        f'Tried to query deposits/withdrawals from {location} which is either '
                        f'not an exchange or not an exchange the user has connected to',
                    )
                    return []
                movements = self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=exchange,
                )
        else:
            # cryptocom has no exchange integration but we may have DB entries due to csv import
            movements = self._query_exchange_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                all_movements=movements,
                exchange=Location.CRYPTOCOM,
            )
            for _, exchange in self.exchange_manager.connected_exchanges.items(
            ):
                self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=exchange,
                )

        # return movements with most recent first
        movements.sort(key=lambda x: x.timestamp, reverse=True)
        return movements

    def set_settings(self, settings: ModifiableDBSettings) -> Tuple[bool, str]:
        """Tries to set new settings. Returns True in success or False with message if error"""
        with self.lock:
            if settings.eth_rpc_endpoint is not None:
                result, msg = self.chain_manager.set_eth_rpc_endpoint(
                    settings.eth_rpc_endpoint)
                if not result:
                    return False, msg

            if settings.ksm_rpc_endpoint is not None:
                result, msg = self.chain_manager.set_ksm_rpc_endpoint(
                    settings.ksm_rpc_endpoint)
                if not result:
                    return False, msg

            if settings.kraken_account_type is not None:
                kraken = self.exchange_manager.get('kraken')
                if kraken:
                    kraken.set_account_type(
                        settings.kraken_account_type)  # type: ignore

            if settings.btc_derivation_gap_limit is not None:
                self.chain_manager.btc_derivation_gap_limit = settings.btc_derivation_gap_limit

            if settings.current_price_oracles is not None:
                Inquirer().set_oracles_order(settings.current_price_oracles)

            if settings.historical_price_oracles is not None:
                PriceHistorian().set_oracles_order(
                    settings.historical_price_oracles)

            self.data.db.set_settings(settings)
            return True, ''

    def get_settings(self) -> DBSettings:
        """Returns the db settings with a check whether premium is active or not"""
        db_settings = self.data.db.get_settings(
            have_premium=self.premium is not None)
        return db_settings

    def setup_exchange(
        self,
        name: str,
        api_key: ApiKey,
        api_secret: ApiSecret,
        passphrase: Optional[str] = None,
    ) -> Tuple[bool, str]:
        """
        Set up a new exchange with an API key, an API secret and optionally a passphrase
        """
        is_success, msg = self.exchange_manager.setup_exchange(
            name=name,
            api_key=api_key,
            api_secret=api_secret,
            database=self.data.db,
            passphrase=passphrase,
        )

        if is_success:
            # Success, save the result in the DB
            self.data.db.add_exchange(name,
                                      api_key,
                                      api_secret,
                                      passphrase=passphrase)
        return is_success, msg

    def remove_exchange(self, name: str) -> Tuple[bool, str]:
        if not self.exchange_manager.has_exchange(name):
            return False, 'Exchange {} is not registered'.format(name)

        self.exchange_manager.delete_exchange(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        self.data.db.delete_used_query_range_for_exchange(name)
        return True, ''

    def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
        """Query for frequently changing data"""
        result: Dict[str, Union[bool, Timestamp]] = {}

        if self.user_is_logged_in:
            result[
                'last_balance_save'] = self.data.db.get_last_balance_save_time(
                )
            result[
                'eth_node_connection'] = self.chain_manager.ethereum.web3_mapping.get(
                    NodeName.OWN, None) is not None  # noqa: E501
            result['last_data_upload_ts'] = Timestamp(
                self.premium_sync_manager.last_data_upload_ts)  # noqa: E501
        return result

    def shutdown(self) -> None:
        self.logout()
        self.shutdown_event.set()

    def _connect_ksm_manager_on_startup(self) -> bool:
        return bool(self.data.db.get_blockchain_accounts().ksm)

    def create_oracle_cache(
        self,
        oracle: HistoricalPriceOracle,
        from_asset: Asset,
        to_asset: Asset,
        purge_old: bool,
    ) -> None:
        """Creates the cache of the given asset pair from the start of time
        until now for the given oracle.

        if purge_old is true then any old cache in memory and in a file is purged

        May raise:
            - RemoteError if there is a problem reaching the oracle
            - UnsupportedAsset if any of the two assets is not supported by the oracle
        """
        if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
            return  # only for cryptocompare for now

        self.cryptocompare.create_cache(from_asset, to_asset, purge_old)

    def delete_oracle_cache(
        self,
        oracle: HistoricalPriceOracle,
        from_asset: Asset,
        to_asset: Asset,
    ) -> None:
        if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
            return  # only for cryptocompare for now

        self.cryptocompare.delete_cache(from_asset, to_asset)

    def get_oracle_cache(
            self, oracle: HistoricalPriceOracle) -> List[Dict[str, Any]]:
        if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
            return []  # only for cryptocompare for now

        return self.cryptocompare.get_all_cache_data()
Code example #37
File: Lock1.py Project: isoundy000/learn_python
class Lock1:

    _gevent_locks = {}
    _gevent_lock = Semaphore()

    @staticmethod
    def AddCurrent(lock):
        # current = gevent.getcurrent()
        # Lock1._gevent_lock.acquire()
        # if not Lock1._gevent_locks.has_key(current):
        #     Lock1._gevent_locks[current] = set()
        #     Lock1._gevent_locks[current].add(lock)
        # Lock1._gevent_lock.release()
        pass

    @staticmethod
    def SubCurrent(lock):
        # current = gevent.getcurrent()
        # Lock1._gevent_lock.acquire()
        # Lock1._gevent_locks[current].discard(lock)
        # Lock1._gevent_lock.release()
        pass

    @staticmethod
    def ClearCurrent():
        # current = gevent.getcurrent()
        # Lock1._gevent_lock.acquire()
        # if Lock1._gevent_locks.has_key(current):
        #     while True:
        #         if len(Lock1._gevent_locks[current]) == 0:
        #             break
        #         locktmp = Lock1._gevent_locks[current].pop()
        #         locktmp.ReleaseEx()
        # Lock1._gevent_lock.release()
        pass

    def __init__(self):
        self._lock = Semaphore()
        self._current = None
        self._count = 0

    def Lock(self, tag=None):
        # if tag:
        #     print "[Lock1]" + tag + " 请求"
        # current = gevent.getcurrent()
        # if self._current == current:
        #     self._count += 1
        #     return
        # if not self._lock.acquire(timeout=1):
        #     raise Exception("lock error")
        # self._current = current
        # self._count = 1
        # Lock1.AddCurrent(self)
        #
        # if tag:
        #     print "[Lock1]" + tag + " 确认"
        pass

    def Release(self, tag=None):
        # self._count -= 1
        # if self._count == 0:
        #     self._lock.release()
        #     self._current = None
        #     Lock1.SubCurrent(self)
        # if tag:
        #     print "[Lock1]" + tag + " 释放"
        pass

    def ReleaseEx(self):
        # self._lock.release()
        # self._current = None
        # self._count = 0
        pass
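The commented-out bodies above sketch a reentrant lock for greenlets: ownership is tracked with gevent.getcurrent() plus a counter, and the underlying Semaphore is only taken on the first nested Lock() call. A minimal runnable sketch of that idea, using hypothetical names rather than the original project's code:

import gevent
from gevent.lock import Semaphore


class ReentrantGreenletLock:
    """Illustrative reentrant lock for greenlets, built on a gevent Semaphore."""

    def __init__(self):
        self._lock = Semaphore()
        self._owner = None
        self._count = 0

    def acquire(self):
        current = gevent.getcurrent()
        if self._owner is current:
            # Re-entered by the owning greenlet: just bump the counter.
            self._count += 1
            return
        self._lock.acquire()
        self._owner = current
        self._count = 1

    def release(self):
        self._count -= 1
        if self._count == 0:
            self._owner = None
            self._lock.release()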
Code example #38
import logging
try:
    from gevent.lock import Semaphore
except ImportError:
    from eventlet.semaphore import Semaphore

from django.db.backends.mysql.base import DatabaseWrapper as OriginalDatabaseWrapper
from .creation import DatabaseCreation
from .connection_pool import MysqlConnectionPool

logger = logging.getLogger('django.geventpool')

connection_pools = {}
connection_pools_lock = Semaphore(value=1)

DEFAULT_MAX_CONNS = 4


class ConnectionPoolMixin(object):
    creation_class = DatabaseCreation

    def __init__(self, settings_dict, *args, **kwargs):
        def pop_max_conn(settings_dict):
            if "OPTIONS" in settings_dict:
                return settings_dict["OPTIONS"].pop("MAX_CONNS",
                                                    DEFAULT_MAX_CONNS)
            else:
                return DEFAULT_MAX_CONNS

        self._pool = None
        settings_dict['CONN_MAX_AGE'] = 0
Code example #39
File: rotkehlchen.py Project: resslerruntime/rotki
    def __init__(self, args: argparse.Namespace) -> None:
        """Initialize the Rotkehlchen object

        May Raise:
        - SystemPermissionError if the given data directory's permissions
        are not correct.
        """
        self.lock = Semaphore()
        self.lock.acquire()

        # Can also be None after unlock if premium credentials did not
        # authenticate or premium server temporarily offline
        self.premium: Optional[Premium] = None
        self.user_is_logged_in: bool = False
        configure_logging(args)

        self.sleep_secs = args.sleep_secs
        if args.data_dir is None:
            self.data_dir = default_data_directory()
        else:
            self.data_dir = Path(args.data_dir)

        if not os.access(self.data_dir, os.W_OK | os.R_OK):
            raise SystemPermissionError(
                f'The given data directory {self.data_dir} is not readable or writable',
            )
        self.args = args
        self.msg_aggregator = MessagesAggregator()
        self.greenlet_manager = GreenletManager(
            msg_aggregator=self.msg_aggregator)
        self.exchange_manager = ExchangeManager(
            msg_aggregator=self.msg_aggregator)
        # Initialize the AssetResolver singleton
        AssetResolver(data_directory=self.data_dir)
        self.data = DataHandler(self.data_dir, self.msg_aggregator)
        self.cryptocompare = Cryptocompare(data_directory=self.data_dir,
                                           database=None)
        self.coingecko = Coingecko()
        self.icon_manager = IconManager(data_dir=self.data_dir,
                                        coingecko=self.coingecko)
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='periodically_query_icons_until_all_cached',
            method=self.icon_manager.periodically_query_icons_until_all_cached,
            batch_size=ICONS_BATCH_SIZE,
            sleep_time_secs=ICONS_QUERY_SLEEP,
        )
        # Initialize the Inquirer singleton
        Inquirer(
            data_dir=self.data_dir,
            cryptocompare=self.cryptocompare,
            coingecko=self.coingecko,
        )
        # Keeps how many trades we have found per location. Used for free user limiting
        self.actions_per_location: Dict[str, Dict[Location, int]] = {
            'trade': defaultdict(int),
            'asset_movement': defaultdict(int),
        }

        self.lock.release()
        self.shutdown_event = gevent.event.Event()
Code example #40
File: test__semaphore.py Project: sigshen/gevent
 def test_acquire_returns_false_after_timeout(self):
     s = Semaphore(value=0)
     result = s.acquire(timeout=0.01)
     assert result is False, repr(result)
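Outside of tests, the same timeout behaviour is what makes a timed or non-blocking acquire useful: the greenlet can back off instead of waiting forever. A small sketch of both variants, assuming gevent's Semaphore.acquire(blocking=True, timeout=None) signature:

from gevent.lock import Semaphore

slot = Semaphore(value=1)

# Non-blocking attempt: returns True or False immediately.
if slot.acquire(blocking=False):
    try:
        pass  # protected work
    finally:
        slot.release()

# Timed attempt: returns False once the timeout expires.
if slot.acquire(timeout=0.5):
    slot.release()
else:
    print('could not get the semaphore within 0.5 seconds')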
Code example #41
File: client.py Project: infuy/lumino_fork
    def __init__(
        self,
        web3: Web3,
        privkey: bytes,
        gas_price_strategy: Callable = rpc_gas_price_strategy,
        gas_estimate_correction: Callable = lambda gas: gas,
        block_num_confirmations: int = 0,
        uses_infura=False,
    ):
        if privkey is None or len(privkey) != 32:
            raise ValueError('Invalid private key')

        if block_num_confirmations < 0:
            raise ValueError('Number of confirmations has to be non-negative')

        monkey_patch_web3(web3, gas_price_strategy)

        try:
            version = web3.version.node
        except ConnectTimeout:
            raise EthNodeCommunicationError("couldn't reach the ethereum node")

        _, eth_node = is_supported_client(version)

        address = privatekey_to_address(privkey)
        address_checksumed = to_checksum_address(address)

        if uses_infura:
            warnings.warn(
                'Infura does not provide an API to '
                'recover the latest used nonce. This may cause the Raiden node '
                'to error on restarts.\n'
                'The error will manifest while there is a pending transaction '
                'from a previous execution in the Ethereum\'s client pool. When '
                'Raiden restarts the same transaction with the same nonce will '
                'be retried and *rejected*, because the nonce is already used.',
            )
            # The first valid nonce is 0, therefore the count is already the next
            # available nonce
            available_nonce = web3.eth.getTransactionCount(
                address_checksumed, 'pending')

        elif eth_node == constants.EthClient.PARITY:
            parity_assert_rpc_interfaces(web3)
            available_nonce = parity_discover_next_available_nonce(
                web3,
                address_checksumed,
            )

        elif eth_node == constants.EthClient.GETH:
            geth_assert_rpc_interfaces(web3)
            available_nonce = geth_discover_next_available_nonce(
                web3,
                address_checksumed,
            )

        else:
            raise EthNodeInterfaceError(
                f'Unsupported Ethereum client {version}')

        self.eth_node = eth_node
        self.privkey = privkey
        self.address = address
        self.web3 = web3
        self.default_block_num_confirmations = block_num_confirmations

        self._available_nonce = available_nonce
        self._nonce_lock = Semaphore()
        self._gas_estimate_correction = gas_estimate_correction

        log.debug(
            'JSONRPCClient created',
            node=pex(self.address),
            available_nonce=available_nonce,
            client=version,
        )
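_available_nonce and _nonce_lock are meant to be used together: each transaction takes the lock, reads the current nonce and increments it, so concurrent greenlets never hand out the same nonce twice. A hedged sketch of that allocation pattern (illustrative only, not the project's actual method):

    def _next_nonce(self):
        # Serialize nonce allocation so concurrent greenlets get distinct values.
        with self._nonce_lock:
            nonce = self._available_nonce
            self._available_nonce += 1
            return nonce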
Code example #42
    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            transport,
            raiden_event_handler,
            message_handler,
            config,
            discovery=None,
    ):
        super().__init__()
        self.tokennetworkids_to_connectionmanagers = dict()
        self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config

        self.signer: Signer = LocalSigner(self.chain.client.privkey)
        self.address = self.signer.address
        self.discovery = discovery
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler
        self.message_handler = message_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.contract_manager = ContractManager(config['contracts_path'])
        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir

            # Two raiden processes must not write to the same database, even
            # though the database itself may be consistent. If more than one
            # nodes writes state changes to the same WAL there are no
            # guarantees about recovery, this happens because during recovery
            # the WAL replay can not be deterministic.
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        self.event_poll_lock = gevent.lock.Semaphore()
        self.gas_reserve_lock = gevent.lock.Semaphore()
        self.payment_identifier_lock = gevent.lock.Semaphore()
Code example #43
class ThreadPool(object):
    def __init__(self, maxsize, hub=None):
        if hub is None:
            hub = get_hub()
        self.hub = hub
        self._maxsize = 0
        self.manager = None
        self.pid = os.getpid()
        self.fork_watcher = hub.loop.fork(ref=False)
        self._init(maxsize)

    def _set_maxsize(self, maxsize):
        if not isinstance(maxsize, integer_types):
            raise TypeError('maxsize must be integer: %r' % (maxsize, ))
        if maxsize < 0:
            raise ValueError('maxsize must not be negative: %r' % (maxsize, ))
        difference = maxsize - self._maxsize
        self._semaphore.counter += difference
        self._maxsize = maxsize
        self.adjust()
        # make sure all currently blocking spawn() start unlocking if maxsize increased
        self._semaphore._start_notify()

    def _get_maxsize(self):
        return self._maxsize

    maxsize = property(_get_maxsize, _set_maxsize)

    def __repr__(self):
        return '<%s at 0x%x %s/%s/%s>' % (self.__class__.__name__, id(self),
                                          len(self), self.size, self.maxsize)

    def __len__(self):
        # XXX just do unfinished_tasks property
        return self.task_queue.unfinished_tasks

    def _get_size(self):
        return self._size

    def _set_size(self, size):
        if size < 0:
            raise ValueError('Size of the pool cannot be negative: %r' %
                             (size, ))
        if size > self._maxsize:
            raise ValueError(
                'Size of the pool cannot be bigger than maxsize: %r > %r' %
                (size, self._maxsize))
        if self.manager:
            self.manager.kill()
        while self._size < size:
            self._add_thread()
        delay = 0.0001
        while self._size > size:
            while self._size - size > self.task_queue.unfinished_tasks:
                self.task_queue.put(None)
            if getcurrent() is self.hub:
                break
            sleep(delay)
            delay = min(delay * 2, .05)
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    size = property(_get_size, _set_size)

    def _init(self, maxsize):
        self._size = 0
        self._semaphore = Semaphore(1)
        self._lock = Lock()
        self.task_queue = Queue()
        self._set_maxsize(maxsize)

    def _on_fork(self):
        # fork() only leaves one thread; also screws up locks;
        # let's re-create locks and threads
        pid = os.getpid()
        if pid != self.pid:
            self.pid = pid
            # Do not mix fork() and threads; since fork() only copies one thread
            # all objects referenced by other threads has refcount that will never
            # go down to 0.
            self._init(self._maxsize)

    def join(self):
        delay = 0.0005
        while self.task_queue.unfinished_tasks > 0:
            sleep(delay)
            delay = min(delay * 2, .05)

    def kill(self):
        self.size = 0

    def _adjust_step(self):
        # if there is a possibility & necessity for adding a thread, do it
        while self._size < self._maxsize and self.task_queue.unfinished_tasks > self._size:
            self._add_thread()
        # while the number of threads is more than maxsize, kill one
        # we do not check what's already in task_queue - it could be all Nones
        while self._size - self._maxsize > self.task_queue.unfinished_tasks:
            self.task_queue.put(None)
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    def _adjust_wait(self):
        delay = 0.0001
        while True:
            self._adjust_step()
            if self._size <= self._maxsize:
                return
            sleep(delay)
            delay = min(delay * 2, .05)

    def adjust(self):
        self._adjust_step()
        if not self.manager and self._size > self._maxsize:
            # might need to feed more Nones into the pool
            self.manager = Greenlet.spawn(self._adjust_wait)

    def _add_thread(self):
        with self._lock:
            self._size += 1
        try:
            start_new_thread(self._worker, ())
        except:
            with self._lock:
                self._size -= 1
            raise

    def spawn(self, func, *args, **kwargs):
        while True:
            semaphore = self._semaphore
            semaphore.acquire()
            if semaphore is self._semaphore:
                break
        try:
            task_queue = self.task_queue
            result = AsyncResult()
            thread_result = ThreadResult(result, hub=self.hub)
            task_queue.put((func, args, kwargs, thread_result))
            self.adjust()
            # rawlink() must be the last call
            result.rawlink(lambda *args: self._semaphore.release())
            # XXX this _semaphore.release() is competing for order with get()
            # XXX this is not good, just make ThreadResult release the semaphore before doing anything else
        except:
            semaphore.release()
            raise
        return result

    def _decrease_size(self):
        if sys is None:
            return
        _lock = getattr(self, '_lock', None)
        if _lock is not None:
            with _lock:
                self._size -= 1

    def _worker(self):
        need_decrease = True
        try:
            while True:
                task_queue = self.task_queue
                task = task_queue.get()
                try:
                    if task is None:
                        need_decrease = False
                        self._decrease_size()
                        # we want first to decrease size, then decrease unfinished_tasks
                        # otherwise, _adjust might think there's one more idle thread that
                        # needs to be killed
                        return
                    func, args, kwargs, result = task
                    try:
                        value = func(*args, **kwargs)
                    except:
                        exc_info = getattr(sys, 'exc_info', None)
                        if exc_info is None:
                            return
                        result.handle_error((self, func), exc_info())
                    else:
                        if sys is None:
                            return
                        result.set(value)
                        del value
                    finally:
                        del func, args, kwargs, result, task
                finally:
                    if sys is None:
                        return
                    task_queue.task_done()
        finally:
            if need_decrease:
                self._decrease_size()

    # XXX apply() should re-raise error by default
    # XXX because that's what builtin apply does
    # XXX check gevent.pool.Pool.apply and multiprocessing.Pool.apply
    def apply_e(self, expected_errors, function, args=None, kwargs=None):
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        success, result = self.spawn(wrap_errors, expected_errors, function,
                                     args, kwargs).get()
        if success:
            return result
        raise result

    def apply(self, func, args=None, kwds=None):
        """Equivalent of the apply() builtin function. It blocks till the result is ready."""
        if args is None:
            args = ()
        if kwds is None:
            kwds = {}
        return self.spawn(func, *args, **kwds).get()

    def apply_cb(self, func, args=None, kwds=None, callback=None):
        result = self.apply(func, args, kwds)
        if callback is not None:
            callback(result)
        return result

    def apply_async(self, func, args=None, kwds=None, callback=None):
        """A variant of the apply() method which returns a Greenlet object.

        If callback is specified then it should be a callable which accepts a single argument. When the result becomes ready
        callback is applied to it (unless the call failed)."""
        if args is None:
            args = ()
        if kwds is None:
            kwds = {}
        return Greenlet.spawn(self.apply_cb, func, args, kwds, callback)

    def map(self, func, iterable):
        return list(self.imap(func, iterable))

    def map_cb(self, func, iterable, callback=None):
        result = self.map(func, iterable)
        if callback is not None:
            callback(result)
        return result

    def map_async(self, func, iterable, callback=None):
        """
        A variant of the map() method which returns a Greenlet object.

        If callback is specified then it should be a callable which accepts a
        single argument.
        """
        return Greenlet.spawn(self.map_cb, func, iterable, callback)

    def imap(self, func, iterable):
        """An equivalent of itertools.imap()"""
        return IMap.spawn(func, iterable, spawn=self.spawn)

    def imap_unordered(self, func, iterable):
        """The same as imap() except that the ordering of the results from the
        returned iterator should be considered in arbitrary order."""
        return IMapUnordered.spawn(func, iterable, spawn=self.spawn)
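gevent ships this pool as gevent.threadpool.ThreadPool. spawn() hands back an AsyncResult, and the internal semaphore seen above caps how many tasks may be outstanding at once. Typical usage, as a short sketch:

import time
from gevent.threadpool import ThreadPool

pool = ThreadPool(3)  # at most 3 OS threads

def slow_square(n):
    time.sleep(0.1)  # a blocking call, executed in a real thread
    return n * n

result = pool.spawn(slow_square, 4)      # returns an AsyncResult
print(result.get())                      # 16
print(pool.map(slow_square, range(5)))   # [0, 1, 4, 9, 16]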
Code example #44
 def __init__(self,conn):
     self.conn = conn
     self.closed = False
     self.connerr = None
     self.conn_mutex = RLock()
     self.conn_cond = Semaphore(0)
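A Semaphore created with an initial value of 0 starts out locked, which makes conn_cond usable as a one-shot signal between greenlets: the waiter blocks in acquire() until another greenlet calls release(). The pattern in isolation, as a sketch:

import gevent
from gevent.lock import Semaphore

ready = Semaphore(0)  # initially locked

def waiter():
    ready.acquire()   # blocks until signalled
    print('connection is ready')

def signaller():
    gevent.sleep(0.1)  # simulate setup work
    ready.release()    # wake the waiter

gevent.joinall([gevent.spawn(waiter), gevent.spawn(signaller)])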
Code example #45
File: rpc_server.py Project: zhouqiang-cl/wloki
class RpcConn(object):
    def __init__(self, sck, server):
        self._sck = sck
        self._rlock = Semaphore()
        self._wlock = Semaphore()
        self._server = server
        self._living_controllers = {}

    def on_request(self, request):
        service = self.get_service(request.service_identifier)
        if service is None:
            return

        method = self.get_service_method(service, request.method_identifier)
        if method is None:
            return

        proto_request = self.get_proto_request(service, method, request)

        req_id = request.call_id
        controller = RpcController()
        self._living_controllers[req_id] = controller

        g = gevent.spawn(self.call_method_and_reply, req_id, service, method,
                         controller, proto_request)
        g.start()

    def call_method_and_reply(self, req_id, service, method, controller,
                              proto_request):
        callback = Callback()
        service.CallMethod(method, controller, proto_request, callback)

        if controller.IsCanceled():
            del self._living_controllers[req_id]
            return

        payload = WirePayload()
        if callback.response is not None and not controller.Failed():
            resp = RpcResponse()
            resp.response_bytes = callback.response.SerializeToString()
            resp.call_id = req_id

            payload.rpc_response.call_id = req_id
            payload.rpc_response.response_bytes = resp.response_bytes
        else:
            if controller.Failed():
                payload.rpc_error.call_id = req_id
                payload.rpc_error.error = controller.ErrorText()

        buf = payload.SerializeToString()

        self._wlock.acquire()
        self._sck.sendall(utils.int32_to_bytes(len(buf)) + buf)
        self._wlock.release()

        # finish this call
        del self._living_controllers[req_id]

    def on_cancel(self, rpc_cancel):
        controller = self._living_controllers.get(rpc_cancel.call_id_to_cancel)
        if controller is not None:
            controller.StartCancel()

    def get_service(self, service_name):
        service = self._server._serivces.get(service_name)
        return service

    def get_service_method(self, service, method_name):
        method = service.DESCRIPTOR.FindMethodByName(method_name)
        return method

    def get_proto_request(self, service, method, request):
        proto_request = service.GetRequestClass(method)()
        proto_request.ParseFromString(request.request_bytes)

        # Check the request parsed correctly
        if not proto_request.IsInitialized():
            return None

        return proto_request

    def run(self):
        while True:
            # read payload package
            self._rlock.acquire()
            sz = utils.read_int32(self._sck)
            if sz is None:
                self._rlock.release()
                break
            buf = utils.readall(self._sck, sz)
            if buf is None:
                self._rlock.release()
                break
            self._rlock.release()

            payload = WirePayload()
            payload.ParseFromString(buf)

            # if has rpc request
            if payload.rpc_cancel.IsInitialized():
                self.on_cancel(payload.rpc_cancel)
            elif payload.rpc_request.IsInitialized():
                self.on_request(payload.rpc_request)
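The separate _rlock and _wlock let a read and a write overlap while still serializing each direction. The run() loop above has to remember to release _rlock on every early break; a hypothetical helper that uses the Semaphore as a context manager keeps that bookkeeping in one place (sketch only, reusing the example's names):

    def _read_payload(self):
        # Serialize reads and never leak the lock on an early return.
        with self._rlock:
            sz = utils.read_int32(self._sck)
            if sz is None:
                return None
            return utils.readall(self._sck, sz)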
Code example #46
File: console_ui.py Project: cwipy/libauto
            draw_all()

    def exposed_set_battery_percent(self, pct):
        """
        `pct` should be an integer in [0, 100].
        """
        with self.lock:
            if not isinstance(pct, int) or not (0 <= pct <= 100):
                raise Exception("Invalid battery percent")
            pct = "{}%".format(pct)
            global battery_sprite
            battery_sprite = header_font.render(pct, True, HEADER_TXT_COLOR)
            draw_all()


from rpyc.utils.server import GeventServer
from rpyc.utils.helpers import classpartial

global_lock = Semaphore(value=1)

ConsoleService = classpartial(ConsoleService, global_lock)

rpc_server = GeventServer(ConsoleService, port=18863)

log.info("RUNNING!")

gevent.joinall([
    gevent.spawn(rpc_server.start),
])
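classpartial pre-binds the shared Semaphore as a constructor argument, so the per-connection ConsoleService instances that GeventServer creates all serialize their drawing through the same lock. A hedged sketch of that wiring, with hypothetical names:

from gevent.lock import Semaphore
from rpyc.utils.helpers import classpartial
import rpyc

shared_lock = Semaphore(value=1)  # behaves as a process-wide mutex

class DrawService(rpyc.Service):
    # Each connection gets its own instance, but they all receive the same lock.
    def __init__(self, lock):
        self.lock = lock

    def exposed_draw(self):
        with self.lock:
            pass  # touch shared screen state here

BoundDrawService = classpartial(DrawService, shared_lock)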

Code example #47
class Rotkehlchen():
    def __init__(self, args: argparse.Namespace) -> None:
        self.lock = Semaphore()
        self.lock.acquire()

        self.premium = None
        self.user_is_logged_in = False

        logfilename = None
        if args.logtarget == 'file':
            logfilename = args.logfile

        if args.loglevel == 'debug':
            loglevel = logging.DEBUG
        elif args.loglevel == 'info':
            loglevel = logging.INFO
        elif args.loglevel == 'warn':
            loglevel = logging.WARN
        elif args.loglevel == 'error':
            loglevel = logging.ERROR
        elif args.loglevel == 'critical':
            loglevel = logging.CRITICAL
        else:
            raise ValueError('Should never get here. Illegal log value')

        logging.basicConfig(
            filename=logfilename,
            filemode='w',
            level=loglevel,
            format='%(asctime)s -- %(levelname)s:%(name)s:%(message)s',
            datefmt='%d/%m/%Y %H:%M:%S %Z',
        )

        if not args.logfromothermodules:
            logging.getLogger('zerorpc').setLevel(logging.CRITICAL)
            logging.getLogger('zerorpc.channel').setLevel(logging.CRITICAL)
            logging.getLogger('urllib3').setLevel(logging.CRITICAL)
            logging.getLogger('urllib3.connectionpool').setLevel(
                logging.CRITICAL)

        self.sleep_secs = args.sleep_secs
        self.data_dir = args.data_dir
        self.args = args
        self.msg_aggregator = MessagesAggregator()
        self.exchange_manager = ExchangeManager(
            msg_aggregator=self.msg_aggregator)
        self.data = DataHandler(self.data_dir, self.msg_aggregator)
        # Initialize the Inquirer singleton
        Inquirer(data_dir=self.data_dir)

        self.lock.release()
        self.shutdown_event = gevent.event.Event()

    def unlock_user(
        self,
        user: str,
        password: str,
        create_new: bool,
        sync_approval: str,
        premium_credentials: Optional[PremiumCredentials],
    ) -> None:
        """Unlocks an existing user or creates a new one if `create_new` is True"""
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
        )
        # unlock or create the DB
        self.password = password
        self.user_directory = self.data.unlock(user, password, create_new)
        self.data_importer = DataImporter(db=self.data.db)
        self.last_data_upload_ts = self.data.db.get_last_data_upload_ts()
        self.premium_sync_manager = PremiumSyncManager(data=self.data,
                                                       password=password)

        try:
            self.premium = self.premium_sync_manager.try_premium_at_start(
                given_premium_credentials=premium_credentials,
                username=user,
                create_new=create_new,
                sync_approval=sync_approval,
            )
        except AuthenticationError:
            # It means that our credentials were not accepted by the server
            # or some other error happened
            pass

        settings = self.data.db.get_settings()
        maybe_submit_usage_analytics(settings.submit_usage_analytics)
        historical_data_start = settings.historical_data_start
        eth_rpc_endpoint = settings.eth_rpc_endpoint
        self.trades_historian = TradesHistorian(
            user_directory=self.user_directory,
            db=self.data.db,
            eth_accounts=self.data.get_eth_accounts(),
            msg_aggregator=self.msg_aggregator,
            exchange_manager=self.exchange_manager,
        )
        # Initialize the price historian singleton
        PriceHistorian(
            data_directory=self.data_dir,
            history_date_start=historical_data_start,
            cryptocompare=Cryptocompare(data_directory=self.data_dir),
        )
        db_settings = self.data.db.get_settings()
        self.accountant = Accountant(
            profit_currency=self.data.main_currency(),
            user_directory=self.user_directory,
            msg_aggregator=self.msg_aggregator,
            create_csv=True,
            ignored_assets=self.data.db.get_ignored_assets(),
            include_crypto2crypto=db_settings.include_crypto2crypto,
            taxfree_after_period=db_settings.taxfree_after_period,
            include_gas_costs=db_settings.include_gas_costs,
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=db_settings.anonymized_logs)
        exchange_credentials = self.data.db.get_exchange_credentials()
        self.exchange_manager.initialize_exchanges(
            exchange_credentials=exchange_credentials,
            database=self.data.db,
        )

        ethchain = Ethchain(eth_rpc_endpoint)
        self.blockchain = Blockchain(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            owned_eth_tokens=self.data.db.get_owned_tokens(),
            ethchain=ethchain,
            msg_aggregator=self.msg_aggregator,
        )
        self.user_is_logged_in = True

    def logout(self) -> None:
        if not self.user_is_logged_in:
            return

        user = self.data.username
        log.info(
            'Logging out user',
            user=user,
        )
        del self.blockchain
        self.exchange_manager.delete_all_exchanges()

        # Reset rotkehlchen logger to default
        LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

        del self.accountant
        del self.trades_historian
        del self.data_importer

        if self.premium is not None:
            # For some reason mypy does not see that self.premium is set
            del self.premium  # type: ignore
        self.data.logout()
        self.password = ''

        self.user_is_logged_in = False
        log.info(
            'User successfully logged out',
            user=user,
        )

    def set_premium_credentials(self, credentials: PremiumCredentials) -> None:
        """
        Sets the premium credentials for Rotki

        Raises AuthenticationError if the given key is rejected by the Rotkehlchen server
        """
        log.info('Setting new premium credentials')
        if self.premium is not None:
            # For some reason mypy does not see that self.premium is set
            self.premium.set_credentials(credentials)  # type: ignore
        else:
            self.premium = premium_create_and_verify(credentials)

        self.data.db.set_rotkehlchen_premium(credentials)

    def start(self) -> gevent.Greenlet:
        return gevent.spawn(self.main_loop)

    def main_loop(self) -> None:
        while self.shutdown_event.wait(MAIN_LOOP_SECS_DELAY) is not True:
            if self.user_is_logged_in:
                log.debug('Main loop start')
                self.premium_sync_manager.maybe_upload_data_to_server()
                log.debug('Main loop end')

    def add_blockchain_account(
        self,
        blockchain: SupportedBlockchain,
        account: BlockchainAddress,
    ) -> Dict:
        try:
            new_data = self.blockchain.add_blockchain_account(
                blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.add_blockchain_account(blockchain, account)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_blockchain_account(
        self,
        blockchain: SupportedBlockchain,
        account: BlockchainAddress,
    ) -> Dict[str, Any]:
        try:
            new_data = self.blockchain.remove_blockchain_account(
                blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.remove_blockchain_account(blockchain, account)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def add_owned_eth_tokens(self, tokens: List[str]) -> Dict[str, Any]:
        ethereum_tokens = [
            EthereumToken(identifier=identifier) for identifier in tokens
        ]
        try:
            new_data = self.blockchain.track_new_tokens(ethereum_tokens)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))

        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_owned_eth_tokens(self, tokens: List[str]) -> Dict[str, Any]:
        ethereum_tokens = [
            EthereumToken(identifier=identifier) for identifier in tokens
        ]
        try:
            new_data = self.blockchain.remove_eth_tokens(ethereum_tokens)
        except InputError as e:
            return simple_result(False, str(e))
        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def process_history(
        self,
        start_ts: Timestamp,
        end_ts: Timestamp,
    ) -> Tuple[Dict[str, Any], str]:
        (
            error_or_empty,
            history,
            loan_history,
            asset_movements,
            eth_transactions,
        ) = self.trades_historian.get_history(
            # For entire history processing we need to have full history available
            start_ts=Timestamp(0),
            end_ts=ts_now(),
        )
        result = self.accountant.process_history(
            start_ts=start_ts,
            end_ts=end_ts,
            trade_history=history,
            loan_history=loan_history,
            asset_movements=asset_movements,
            eth_transactions=eth_transactions,
        )
        return result, error_or_empty

    def query_fiat_balances(self) -> Dict[Asset, Dict[str, FVal]]:
        result = {}
        balances = self.data.get_fiat_balances()
        for currency, str_amount in balances.items():
            amount = FVal(str_amount)
            usd_rate = Inquirer().query_fiat_pair(currency, A_USD)
            result[currency] = {
                'amount': amount,
                'usd_value': amount * usd_rate,
            }

        return result

    def query_balances(
        self,
        requested_save_data: bool = False,
        timestamp: Timestamp = None,
    ) -> Dict[str, Any]:
        """Query all balances rotkehlchen can see.

        If requested_save_data is True then the data are saved in the DB.
        If timestamp is None then the current timestamp is used.
        If a timestamp is given then that is the time that the balances are going
        to be saved in the DB

        Returns a dictionary with the queried balances.
        """
        log.info('query_balances called',
                 requested_save_data=requested_save_data)

        balances = {}
        problem_free = True
        for _, exchange in self.exchange_manager.connected_exchanges.items():
            exchange_balances, _ = exchange.query_balances()
            # If we got an error, disregard that exchange but make sure we don't save data
            if not isinstance(exchange_balances, dict):
                problem_free = False
            else:
                balances[exchange.name] = exchange_balances

        result, error_or_empty = self.blockchain.query_balances()
        if error_or_empty == '':
            balances['blockchain'] = result['totals']
        else:
            problem_free = False

        result = self.query_fiat_balances()
        if result != {}:
            balances['banks'] = result

        combined = combine_stat_dicts([v for k, v in balances.items()])
        total_usd_per_location = [(k, dict_get_sumof(v, 'usd_value'))
                                  for k, v in balances.items()]

        # calculate net usd value
        net_usd = FVal(0)
        for _, v in combined.items():
            net_usd += FVal(v['usd_value'])

        stats: Dict[str, Any] = {
            'location': {},
            'net_usd': net_usd,
        }
        for entry in total_usd_per_location:
            name = entry[0]
            total = entry[1]
            if net_usd != FVal(0):
                percentage = (total / net_usd).to_percentage()
            else:
                percentage = '0%'
            stats['location'][name] = {
                'usd_value': total,
                'percentage_of_net_value': percentage,
            }

        for k, v in combined.items():
            if net_usd != FVal(0):
                percentage = (v['usd_value'] / net_usd).to_percentage()
            else:
                percentage = '0%'
            combined[k]['percentage_of_net_value'] = percentage

        result_dict = merge_dicts(combined, stats)

        allowed_to_save = requested_save_data or self.data.should_save_balances(
        )
        if problem_free and allowed_to_save:
            if not timestamp:
                timestamp = Timestamp(int(time.time()))
            self.data.save_balances_data(data=result_dict, timestamp=timestamp)
            log.debug('query_balances data saved')
        else:
            log.debug(
                'query_balances data not saved',
                allowed_to_save=allowed_to_save,
                problem_free=problem_free,
            )

        # After adding it to the saved file we can overlay additional data that
        # is not required to be saved in the history file
        try:
            details = self.accountant.events.details
            for asset, (tax_free_amount, average_buy_value) in details.items():
                if asset not in result_dict:
                    continue

                result_dict[asset]['tax_free_amount'] = tax_free_amount
                result_dict[asset]['average_buy_value'] = average_buy_value

                current_price = result_dict[asset]['usd_value'] / result_dict[
                    asset]['amount']
                if average_buy_value != FVal(0):
                    result_dict[asset]['percent_change'] = (
                        ((current_price - average_buy_value) /
                         average_buy_value) * 100)
                else:
                    result_dict[asset]['percent_change'] = 'INF'

        except AttributeError:
            pass

        return result_dict

    def set_main_currency(self, currency_string: str) -> Tuple[bool, str]:
        """Takes a currency string from the API and sets it as the main currency for rotki

        Returns True and empty string for success and False and error string for error
        """
        try:
            currency = Asset(currency_string)
        except UnknownAsset:
            msg = f'An unknown asset {currency_string} was given for main currency'
            log.critical(msg)
            return False, msg

        if not currency.is_fiat():
            msg = f'A non-fiat asset {currency_string} was given for main currency'
            log.critical(msg)
            return False, msg

        fiat_currency = FiatAsset(currency.identifier)
        with self.lock:
            self.data.set_main_currency(fiat_currency, self.accountant)

        return True, ''

    def set_settings(self, settings: Dict[str, Any]) -> Tuple[bool, str]:
        log.info('Add new settings')

        message = ''
        with self.lock:
            if 'eth_rpc_endpoint' in settings:
                result, msg = self.blockchain.set_eth_rpc_endpoint(
                    settings['eth_rpc_endpoint'])
                if not result:
                    # Don't save it in the DB
                    del settings['eth_rpc_endpoint']
                    message += "\nEthereum RPC endpoint not set: " + msg

            if 'main_currency' in settings:
                given_symbol = settings['main_currency']
                try:
                    main_currency = Asset(given_symbol)
                except UnknownAsset:
                    return False, f'Unknown fiat currency {given_symbol} provided'
                except DeserializationError:
                    return False, 'Non string type given for fiat currency'

                if not main_currency.is_fiat():
                    msg = (
                        f'Provided symbol for main currency {given_symbol} is '
                        f'not a fiat currency')
                    return False, msg

            res, msg = self.accountant.customize(settings)
            if not res:
                message += '\n' + msg
                return False, message

            self.data.set_settings(settings, self.accountant)

            # Always return success here but with a message
            return True, message

    def setup_exchange(
        self,
        name: str,
        api_key: str,
        api_secret: str,
    ) -> Tuple[bool, str]:
        """
        Setup a new exchange with an api key and an api secret

        By default the api keys are always validated unless validate is False.
        """
        is_success, msg = self.exchange_manager.setup_exchange(
            name=name,
            api_key=api_key,
            api_secret=api_secret,
            database=self.data.db,
        )

        if is_success:
            # Success, save the result in the DB
            self.data.db.add_exchange(name, api_key, api_secret)
        return is_success, msg

    def remove_exchange(self, name: str) -> Tuple[bool, str]:
        if not self.exchange_manager.has_exchange(name):
            return False, 'Exchange {} is not registered'.format(name)

        self.exchange_manager.delete_exchange(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        return True, ''

    def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
        """Query for frequently changing data"""
        result: Dict[str, Union[bool, Timestamp]] = {}

        if self.user_is_logged_in:
            result[
                'last_balance_save'] = self.data.db.get_last_balance_save_time(
                )
            result['eth_node_connection'] = self.blockchain.ethchain.connected
            result[
                'history_process_start_ts'] = self.accountant.started_processing_timestamp
            result[
                'history_process_current_ts'] = self.accountant.currently_processing_timestamp
        return result

    def shutdown(self) -> None:
        self.logout()
        self.shutdown_event.set()
Code example #48
 def _init(self, maxsize):
     self._size = 0
     self._semaphore = Semaphore(1)
     self._lock = Lock()
     self.task_queue = Queue()
     self._set_maxsize(maxsize)
Code example #49
File: Lock1.py Project: isoundy000/learn_python
 def __init__(self):
     self._lock = Semaphore()
     self._current = None
     self._count = 0
Code example #50
File: utils.py Project: binaryflesh/raiden
                if ex.code != 400:
                    raise
                log.debug('Username taken. Continuing')
                continue
    else:
        raise ValueError('Could not register or login!')

    name = encode_hex(signer.sign(client.user_id.encode()))
    user = client.get_user(client.user_id)
    user.set_display_name(name)
    return user


@cached(cache=LRUCache(128),
        key=attrgetter('user_id', 'displayname'),
        lock=Semaphore())
def validate_userid_signature(user: User) -> Optional[Address]:
    """ Validate a userId format and signature on displayName, and return its address"""
    # display_name should be an address in the USERID_RE format
    match = USERID_RE.match(user.user_id)
    if not match:
        return None

    encoded_address = match.group(1)
    address: Address = to_canonical_address(encoded_address)

    try:
        displayname = user.get_display_name()
        recovered = recover(
            data=user.user_id.encode(),
            signature=decode_hex(displayname),
Code example #51
 def __init__(self, web3: Web3, filter_params: dict):
     super().__init__(web3, filter_id=None)
     self.filter_params = filter_params
     self._last_block: int = -1
     self._lock = Semaphore()
Code example #52
class RaidenService(Runnable):
    """ A Raiden node. """

    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: typing.BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            private_key_bin,
            transport,
            raiden_event_handler,
            config,
            discovery=None,
    ):
        super().__init__()
        if not isinstance(private_key_bin, bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        self.tokennetworkids_to_connectionmanagers = dict()
        self.identifier_to_results: typing.Dict[
            typing.PaymentID,
            AsyncResult,
        ] = dict()

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)
        self.discovery = discovery

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Prevent concurrent access to the same db
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        self.event_poll_lock = gevent.lock.Semaphore()

    def start(self):
        """ Start the node synchronously. Raises directly if anything went wrong on startup """
        if not self.stop_event.ready():
            raise RuntimeError(f'{self!r} already started')
        self.stop_event.clear()

        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        storage = sqlite.SQLiteStorage(self.database_path, serialize.JSONSerializer())
        self.wal = wal.restore_to_state_change(
            transition_function=node.state_transition,
            storage=storage,
            state_change_identifier='latest',
        )

        if self.wal.state_manager.current_state is None:
            log.debug(
                'No recoverable state available, created initial state',
                node=pex(self.address),
            )
            block_number = self.chain.block_number()

            state_change = ActionInitChain(
                random.Random(),
                block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.wal.log_and_dispatch(state_change)
            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states as it's the node's startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.EMPTY_HASH,
                payment_network,
            )
            self.handle_state_change(state_change)

            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = 0
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(self.wal.state_manager.current_state)
            log.debug(
                'Restored state from WAL',
                last_restored_block=last_log_block_number,
                node=pex(self.address),
            )

            known_networks = views.get_payment_network_identifiers(views.state_from_raiden(self))
            if known_networks and self.default_registry.address not in known_networks:
                configured_registry = pex(self.default_registry.address)
                known_registries = lpex(known_networks)
                raise RuntimeError(
                    f'Token network address mismatch.\n'
                    f'Raiden is configured to use the smart contract '
                    f'{configured_registry}, which conflicts with the current known '
                    f'smart contracts {known_registries}',
                )

        # Clear ref cache & disable caching
        serialize.RaidenJSONDecoder.ref_cache.clear()
        serialize.RaidenJSONDecoder.cache_object_references = False

        # Restore the current snapshot group
        state_change_qty = self.wal.storage.count_state_changes()
        self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        # otherwise the state changes won't have effect.
        # - The alarm must complete its first run  before the transport is started,
        #  to avoid rejecting messages for unknown channels.
        self.alarm.register_callback(self._callback_new_block)

        # alarm.first_run may process some new channel, which would start_health_check_for
        # a partner, that's why transport needs to be already started at this point
        self.transport.start(self)

        self.alarm.first_run()

        chain_state = views.state_from_raiden(self)
        # Dispatch pending transactions
        pending_transactions = views.get_pending_transactions(
            chain_state,
        )
        log.debug(
            'Processing pending transactions',
            num_pending_transactions=len(pending_transactions),
            node=pex(self.address),
        )
        with self.dispatch_events_lock:
            for transaction in pending_transactions:
                try:
                    self.raiden_event_handler.on_raiden_event(self, transaction)
                except RaidenRecoverableError as e:
                    log.error(str(e))
                except InvalidDBData as e:
                    raise
                except RaidenUnrecoverableError as e:
                    if self.config['network_type'] == NetworkType.MAIN:
                        log.error(str(e))
                    else:
                        raise

        self.alarm.start()

        # after transport and alarm is started, send queued messages
        events_queues = views.get_all_messagequeues(chain_state)

        for queue_identifier, event_queue in events_queues.items():
            self.start_health_check_for(queue_identifier.recipient)

            # repopulate identifier_to_results for pending transfers
            for event in event_queue:
                if type(event) == SendDirectTransfer:
                    self.identifier_to_results[event.payment_identifier] = AsyncResult()

                message = message_from_sendevent(event, self.address)
                self.sign(message)
                self.transport.send_async(queue_identifier, message)

        # exceptions on these subtasks should crash the app and bubble up
        self.alarm.link_exception(self.on_error)
        self.transport.link_exception(self.on_error)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck()

        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet.get()  # re-raise if exception occurred

        super().start()

    def _run(self):
        """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """
        try:
            self.stop_event.wait()
        except gevent.GreenletExit:  # killed without exception
            self.stop_event.set()
            gevent.killall([self.alarm, self.transport])  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()
            raise

    def stop(self):
        """ Stop the node gracefully. Raise if any stop-time error occurred on any subtask """
        if self.stop_event.ready():  # not started
            return

        # Needs to come before any greenlets joining
        self.stop_event.set()

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        self.transport.stop()
        self.alarm.stop()

        self.transport.join()
        self.alarm.join()

        self.blockchain_events.uninstall_all_event_listeners()

        if self.db_lock is not None:
            self.db_lock.release()

    def add_pending_greenlet(self, greenlet: gevent.Greenlet):
        greenlet.link_exception(self.on_error)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def start_neighbours_healthcheck(self):
        for neighbour in views.all_neighbour_nodes(self.wal.state_manager.current_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def get_block_number(self):
        return views.block_number(self.wal.state_manager.current_state)

    def handle_state_change(self, state_change):
        log.debug('STATE CHANGE', node=pex(self.address), state_change=state_change)

        event_list = self.wal.log_and_dispatch(state_change)

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug('RAIDEN EVENT', node=pex(self.address), raiden_event=event)

            try:
                self.raiden_event_handler.on_raiden_event(
                    raiden=self,
                    event=event,
                )
            except RaidenRecoverableError as e:
                log.error(str(e))
            except InvalidDBData as e:
                raise
            except RaidenUnrecoverableError as e:
                if self.config['network_type'] == NetworkType.MAIN:
                    log.error(str(e))
                else:
                    raise

        # Take a snapshot every SNAPSHOT_STATE_CHANGES_COUNT state changes.
        # TODO: Gather more data about storage requirements and update the
        # value to specify how often a snapshot should be taken.
        new_snapshot_group = self.wal.storage.count_state_changes() // SNAPSHOT_STATE_CHANGES_COUNT
        if new_snapshot_group > self.snapshot_group:
            log.debug(f'Storing snapshot: {new_snapshot_group}')
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        return event_list

    def set_node_network_state(self, node_address, network_state):
        state_change = ActionChangeNodeNetworkState(node_address, network_state)
        self.wal.log_and_dispatch(state_change)

    def start_health_check_for(self, node_address):
        self.transport.start_health_check(node_address)

    def _callback_new_block(self, latest_block):
        """Called once a new block is detected by the alarm task.

        Note:
            This should be called only once per block, otherwise there will be
            duplicated `Block` state changes in the log.

            Therefore this method should be called only once a new block is
            mined with the corresponding block data from the AlarmTask.
        """
        # User facing APIs, which have on-chain side-effects, force polled the
        # blockchain to update the node's state. This force poll is used to
        # provide a consistent view to the user, e.g. a channel open call waits
        # for the transaction to be mined and force polled the event to update
        # the node's state. This pattern introduced a race with the alarm task
        # and the task which served the user request, because the events are
        # returned only once per filter. The lock below is to protect against
        # these races (introduced by the commit
        # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d)
        with self.event_poll_lock:
            latest_block_number = latest_block['number']

            for event in self.blockchain_events.poll_blockchain_events(latest_block_number):
                # These state changes will be processed with a block_number
                # which is /larger/ than the ChainState's block_number.
                on_blockchain_event(self, event)

            # On restart the Raiden node will re-create the filters with the
            # ethereum node. These filters will have the from_block set to the
            # value of the latest Block state change. To avoid missing events
            # the Block state change is dispatched only after all of the events
            # have been processed.
            #
            # This means on some corner cases a few events may be applied
            # twice, this will happen if the node crashed and some events have
            # been processed but the Block state change has not been
            # dispatched.
            state_change = Block(
                block_number=latest_block_number,
                gas_limit=latest_block['gasLimit'],
                block_hash=bytes(latest_block['hash']),
            )
            self.handle_state_change(state_change)

    def sign(self, message):
        """ Sign message inplace. """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.private_key)

    def install_all_blockchain_filters(
            self,
            token_network_registry_proxy: TokenNetworkRegistry,
            secret_registry_proxy: SecretRegistry,
            from_block: typing.BlockNumber,
    ):
        with self.event_poll_lock:
            node_state = views.state_from_raiden(self)
            token_networks = views.get_token_network_identifiers(
                node_state,
                token_network_registry_proxy.address,
            )

            self.blockchain_events.add_token_network_registry_listener(
                token_network_registry_proxy,
                from_block,
            )
            self.blockchain_events.add_secret_registry_listener(
                secret_registry_proxy,
                from_block,
            )

            for token_network in token_networks:
                token_network_proxy = self.chain.token_network(token_network)
                self.blockchain_events.add_token_network_listener(
                    token_network_proxy,
                    from_block,
                )

    def connection_manager_for_token_network(self, token_network_identifier):
        if not is_binary_address(token_network_identifier):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_identifiers(
            views.state_from_raiden(self),
            self.default_registry.address,
        )

        if token_network_identifier not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokennetworkids_to_connectionmanagers.get(token_network_identifier)

        if manager is None:
            manager = ConnectionManager(self, token_network_identifier)
            self.tokennetworkids_to_connectionmanagers[token_network_identifier] = manager

        return manager

    def leave_all_token_networks(self):
        state_change = ActionLeaveAllNetworks()
        self.wal.log_and_dispatch(state_change)

    def close_and_settle(self):
        log.info('raiden will close and settle all channels now')

        self.leave_all_token_networks()

        connection_managers = list(self.tokennetworkids_to_connectionmanagers.values())

        if connection_managers:
            waiting.wait_for_settle_all_channels(
                self,
                self.alarm.sleep_time,
            )

    def mediated_transfer_async(
            self,
            token_network_identifier: typing.TokenNetworkID,
            amount: typing.TokenAmount,
            target: typing.Address,
            identifier: typing.PaymentID,
    ):
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer; the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.
        """

        async_result = self.start_mediated_transfer(
            token_network_identifier,
            amount,
            target,
            identifier,
        )

        return async_result

    def direct_transfer_async(self, token_network_identifier, amount, target, identifier):
        """ Do a direct transfer with target.

        Direct transfers are non-cancellable and non-expirable, since these
        transfers are a signed balance proof with the transferred amount
        incremented.

        Because the transfer is non-cancellable, there is a level of trust
        with the target. After the message is sent the target is effectively
        paid and the transfer cannot be reverted.

        The async result will be set to False if and only if there is no
        direct channel with the target or the payer does not have enough
        balance to complete the transfer. Otherwise, because the transfer is
        non-expirable, the async result *will never be set to False*; if the
        message is sent it will hang until the target node acknowledges it.

        This transfer should be used as an optimization, since only two
        packets are required to complete the transfer (from the payer's
        perspective), whereas the mediated transfer requires 6 messages.
        """

        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        direct_transfer = ActionTransferDirect(
            token_network_identifier,
            target,
            identifier,
            amount,
        )

        async_result = AsyncResult()
        self.identifier_to_results[identifier] = async_result

        self.handle_state_change(direct_transfer)

    def start_mediated_transfer(
            self,
            token_network_identifier: typing.TokenNetworkID,
            amount: typing.TokenAmount,
            target: typing.Address,
            identifier: typing.PaymentID,
    ):

        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        if identifier in self.identifier_to_results:
            return self.identifier_to_results[identifier]

        async_result = AsyncResult()
        self.identifier_to_results[identifier] = async_result

        secret = random_secret()
        init_initiator_statechange = initiator_init(
            self,
            identifier,
            amount,
            secret,
            token_network_identifier,
            target,
        )

        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return async_result

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        self.start_health_check_for(transfer.initiator)
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)
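
A hedged usage sketch for the transfer API above (hypothetical calling code, not part of the example; `raiden`, `token_network_id` and `partner` are assumed to exist):

# Start a mediated transfer and wait for its outcome. The AsyncResult
# returned by mediated_transfer_async is expected to be set once the
# payment completes.
async_result = raiden.mediated_transfer_async(
    token_network_identifier=token_network_id,
    amount=10,
    target=partner,
    identifier=None,  # a default identifier is created when None is given
)
# gevent's AsyncResult.wait() returns the stored value, or None on timeout.
transfer_successful = async_result.wait(timeout=30)
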
コード例 #53
    def __init__(
        self,
        chain: BlockChainService,
        query_start_block: BlockNumber,
        default_registry: TokenNetworkRegistry,
        default_secret_registry: SecretRegistry,
        private_key_bin,
        transport,
        raiden_event_handler,
        message_handler,
        config,
        discovery=None,
    ):
        super().__init__()
        if not isinstance(private_key_bin, bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        self.tokennetworkids_to_connectionmanagers = dict()
        self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)
        self.discovery = discovery

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler
        self.message_handler = message_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state change events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.contract_manager = ContractManager(config['contracts_path'])
        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Prevent concurrent access to the same db
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        self.event_poll_lock = gevent.lock.Semaphore()
        self.gas_reserve_lock = gevent.lock.Semaphore()
        self.payment_identifier_lock = gevent.lock.Semaphore()
コード例 #54
 def __init__(self, *args, **kwargs):
     # The raw (non-gevent) socket, if possible
     self._socket = None
     BaseServer.__init__(self, *args, **kwargs)
     from gevent.lock import Semaphore
     self._writelock = Semaphore()
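
A minimal sketch of the pattern above (hypothetical code, assuming a connected gevent socket `sock`): a write lock serializes writes issued from multiple greenlets.

from gevent.lock import Semaphore

write_lock = Semaphore()

def send_all(sock, data):
    # gevent's Semaphore supports the `with` statement, so the lock is
    # always released even if sendall() raises.
    with write_lock:
        sock.sendall(data)
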
コード例 #55
 def __init__(self) -> None:
     super().__init__()
     self.query_locks_map: Dict[int, Semaphore] = defaultdict(Semaphore)
     # Accessing and writing to the query_locks map also needs to be protected
     self.query_locks_map_lock = Semaphore()
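
A hypothetical companion method for the class above, sketching how the two locks cooperate: the map lock only guards the defaultdict lookup, while the per-query semaphore is held for the duration of the work itself.

def run_exclusive(self, query_id, func):
    # Look up (or lazily create) the per-query lock under the map lock.
    with self.query_locks_map_lock:
        query_lock = self.query_locks_map[query_id]
    # Hold only the per-query lock while doing the actual work, so queries
    # with different ids can still run concurrently.
    with query_lock:
        return func()
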
コード例 #56

import gevent
from gevent.lock import Semaphore

sem = Semaphore(1)

def fun1():
    for i in range(5):
        sem.acquire()
        print("I am fun 1 this is %s"%i)
        sem.release()
def fun2():
    for i in range(5):
        sem.acquire()
        print("I am fun 2 this is %s"%i)
        sem.release()

# Spawn both greenlets; the shared semaphore serializes the prints.
t1 = gevent.spawn(fun1)
t2 = gevent.spawn(fun2)

gevent.joinall([t1, t2])
コード例 #57
 def __init__(self, name, task):
     self.name = name
     self.task = task
     self.lock = Semaphore(task.max_concurrent)
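
A hypothetical companion method sketching how the lock above would be used; `self.task.run` is an assumed API that is not shown in the snippet.

def run(self, *args, **kwargs):
    # At most task.max_concurrent greenlets can hold the semaphore, which
    # caps concurrent executions of this task.
    with self.lock:
        return self.task.run(*args, **kwargs)
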
コード例 #58
class Queue(Greenlet):
    """Manages the queue of |Envelope| objects waiting for delivery. This is
    not a standard FIFO queue; a message's place in the queue depends entirely
    on the timestamp of its next delivery attempt.

    :param store: Object implementing :class:`QueueStorage`.
    :param relay: |Relay| object used to attempt message deliveries. If this
                  is not given, no deliveries will be attempted on received
                  messages.
    :param backoff: Function that, given an |Envelope| and number of delivery
                    attempts, will return the number of seconds before the next
                    attempt. If it returns ``None``, the message will be
                    permanently failed. The default backoff function simply
                    returns ``None`` and messages are never retried.
    :param bounce_factory: Function that produces a |Bounce| object given the
                           same parameters as the |Bounce| constructor. If the
                           function returns ``None``, no bounce is delivered.
                           By default, a new |Bounce| is created in every case.
    :param bounce_queue: |Queue| object that will be used for delivering bounce
                         messages. The default is ``self``.
    :param store_pool: Number of simultaneous operations performable against
                       the ``store`` object. Default is unlimited.
    :param relay_pool: Number of simultaneous operations performable against
                       the ``relay`` object. Default is unlimited.

    """
    def __init__(self,
                 store,
                 relay=None,
                 backoff=None,
                 bounce_factory=None,
                 bounce_queue=None,
                 store_pool=None,
                 relay_pool=None):
        super(Queue, self).__init__()
        self.store = store
        self.relay = relay
        self.backoff = backoff or self._default_backoff
        self.bounce_factory = bounce_factory or Bounce
        self.bounce_queue = bounce_queue or self
        self.wake = Event()
        self.queued = []
        self.active_ids = set()
        self.queued_ids = set()
        self.queued_lock = Semaphore(1)
        self.queue_policies = []
        self._use_pool('store_pool', store_pool)
        self._use_pool('relay_pool', relay_pool)

    def add_policy(self, policy):
        """Adds a |QueuePolicy| to be executed before messages are persisted
        to storage.

        :param policy: |QueuePolicy| object to execute.

        """
        if isinstance(policy, QueuePolicy):
            self.queue_policies.append(policy)
        else:
            raise TypeError('Argument not a QueuePolicy.')

    @staticmethod
    def _default_backoff(envelope, attempts):
        pass

    def _run_policies(self, envelope):
        results = [envelope]

        def recurse(current, i):
            try:
                policy = self.queue_policies[i]
            except IndexError:
                return
            ret = policy.apply(current)
            if ret:
                results.remove(current)
                results.extend(ret)
                for env in ret:
                    recurse(env, i + 1)
            else:
                recurse(current, i + 1)

        recurse(envelope, 0)
        return results

    def _use_pool(self, attr, pool):
        if pool is None:
            pass
        elif isinstance(pool, Pool):
            setattr(self, attr, pool)
        else:
            setattr(self, attr, Pool(pool))

    def _pool_run(self, which, func, *args, **kwargs):
        pool = getattr(self, which + '_pool', None)
        if pool:
            ret = pool.spawn(func, *args, **kwargs)
            return ret.get()
        else:
            return func(*args, **kwargs)

    def _pool_imap(self, which, func, *iterables):
        pool = getattr(self, which + '_pool', gevent)
        threads = imap(pool.spawn, repeat(func), *iterables)
        ret = []
        for thread in threads:
            thread.join()
            ret.append(thread.exception or thread.value)
        return ret

    def _pool_spawn(self, which, func, *args, **kwargs):
        pool = getattr(self, which + '_pool', gevent)
        return pool.spawn(func, *args, **kwargs)

    def _add_queued(self, entry):
        timestamp, id = entry
        if id not in self.queued_ids | self.active_ids:
            bisect.insort(self.queued, entry)
            self.queued_ids.add(id)
            self.wake.set()

    def enqueue(self, envelope):
        """Drops a new message in the queue for delivery. The first delivery
        attempt is made immediately (depending on relay pool availability).
        This method is not typically called directly, |Edge| objects use it
        when they receive new messages.

        :param envelope: |Envelope| object to enqueue.
        :returns: Zipped list of envelopes and their respective queue IDs (or
                  thrown :exc:`QueueError` objects).

        """
        now = time.time()
        envelopes = self._run_policies(envelope)
        ids = self._pool_imap('store', self.store.write, envelopes,
                              repeat(now))
        results = list(zip(envelopes, ids))
        for env, id in results:
            if not isinstance(id, BaseException):
                if self.relay:
                    self.active_ids.add(id)
                    self._pool_spawn('relay', self._attempt, id, env, 0)
            elif not isinstance(id, QueueError):
                raise id  # Re-raise exceptions that are not QueueError.
        return results

    def _load_all(self):
        for entry in self.store.load():
            self._add_queued(entry)

    def _remove(self, id):
        self._pool_spawn('store', self.store.remove, id)
        self.queued_ids.discard(id)
        self.active_ids.discard(id)

    def _bounce(self, envelope, reply):
        bounce = self.bounce_factory(envelope, reply)
        if bounce:
            return self.bounce_queue.enqueue(bounce)

    def _perm_fail(self, id, envelope, reply):
        if id is not None:
            self._remove(id)
        if envelope.sender:  # Can't bounce to null-sender.
            self._pool_spawn('bounce', self._bounce, envelope, reply)

    def _split_by_reply(self, envelope, replies):
        if isinstance(replies, Reply):
            return [(replies, envelope)]
        groups = []
        for i, rcpt in enumerate(envelope.recipients):
            for reply, group_env in groups:
                if replies[i] == reply:
                    group_env.recipients.append(rcpt)
                    break
            else:
                group_env = envelope.copy([rcpt])
                groups.append((replies[i], group_env))
        return groups

    def _retry_later(self, id, envelope, replies):
        attempts = self.store.increment_attempts(id)
        wait = self.backoff(envelope, attempts)
        if wait is None:
            for reply, group_env in self._split_by_reply(envelope, replies):
                reply.message += ' (Too many retries)'
                self._perm_fail(None, group_env, reply)
            self._remove(id)
            return False
        else:
            when = time.time() + wait
            self.store.set_timestamp(id, when)
            self.active_ids.discard(id)
            self._add_queued((when, id))
            return True

    def _attempt(self, id, envelope, attempts):
        try:
            results = self.relay._attempt(envelope, attempts)
        except TransientRelayError as e:
            self._pool_spawn('store', self._retry_later, id, envelope, e.reply)
        except PermanentRelayError as e:
            self._perm_fail(id, envelope, e.reply)
        except Exception as e:
            log_exception(__name__)
            reply = Reply('450', '4.0.0 Unhandled delivery error: ' + str(e))
            self._pool_spawn('store', self._retry_later, id, envelope, reply)
            raise
        else:
            if isinstance(results, collections.Mapping):
                self._handle_partial_relay(id, envelope, attempts, results)
            elif isinstance(results, collections.Sequence):
                results = dict(zip(envelope.recipients, results))
                self._handle_partial_relay(id, envelope, attempts, results)
            else:
                self._remove(id)

    def _handle_partial_relay(self, id, envelope, attempts, results):
        delivered = set()
        tempfails = []
        permfails = []
        for rcpt, rcpt_res in results.items():
            if rcpt_res is None or isinstance(rcpt_res, Reply):
                delivered.add(envelope.recipients.index(rcpt))
            elif isinstance(rcpt_res, PermanentRelayError):
                delivered.add(envelope.recipients.index(rcpt))
                permfails.append((rcpt, rcpt_res.reply))
            elif isinstance(rcpt_res, TransientRelayError):
                tempfails.append((rcpt, rcpt_res.reply))
        if permfails:
            rcpts, replies = zip(*permfails)
            fail_env = envelope.copy(rcpts)
            for reply, group_env in self._split_by_reply(fail_env, replies):
                self._perm_fail(None, group_env, reply)
        if tempfails:
            rcpts, replies = zip(*tempfails)
            fail_env = envelope.copy(rcpts)
            if not self._retry_later(id, fail_env, replies):
                return
        else:
            self.store.remove(id)
            return
        self.store.set_recipients_delivered(id, delivered)

    def _dequeue(self, id):
        try:
            envelope, attempts = self.store.get(id)
        except KeyError:
            return
        self.active_ids.add(id)
        self._pool_spawn('relay', self._attempt, id, envelope, attempts)

    def _check_ready(self, now):
        last_i = 0
        for i, entry in enumerate(self.queued):
            timestamp, entry_id = entry
            if now >= timestamp:
                self._pool_spawn('store', self._dequeue, entry_id)
                last_i = i + 1
            else:
                break
        if last_i > 0:
            self.queued = self.queued[last_i:]
            self.queued_ids = set([id for _, id in self.queued])

    def _wait_store(self):
        while True:
            try:
                for entry in self.store.wait():
                    self._add_queued(entry)
            except NotImplementedError:
                return

    def _wait_ready(self, now):
        try:
            first = self.queued[0]
        except IndexError:
            self.wake.wait()
            self.wake.clear()
            return
        first_timestamp = first[0]
        if first_timestamp > now:
            self.wake.wait(first_timestamp - now)
            self.wake.clear()

    def flush(self):
        """Attempts to immediately flush all messages waiting in the queue,
        regardless of their retry timers.

        .. warning::

           This can be a very expensive operation, use with care.

        """
        self.wake.set()
        self.wake.clear()
        self.queued_lock.acquire()
        try:
            for entry in self.queued:
                self._pool_spawn('store', self._dequeue, entry[1])
            self.queued = []
        finally:
            self.queued_lock.release()

    def kill(self):
        """This method is used by |Queue| and |Queue|-like objects to properly
        end any associated services (such as running :class:`~gevent.Greenlet`
        threads) and close resources.

        """
        super(Queue, self).kill()

    def _run(self):
        if not self.relay:
            return
        self._pool_spawn('store', self._load_all)
        self._pool_spawn('store', self._wait_store)
        while True:
            self.queued_lock.acquire()
            try:
                now = time.time()
                self._check_ready(now)
                self._wait_ready(now)
            finally:
                self.queued_lock.release()
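
A sketch of a custom backoff function following the contract described in the Queue docstring above (return the seconds until the next attempt, or None to permanently fail the message); `store` and `relay` are assumed QueueStorage and |Relay| objects.

def exponential_backoff(envelope, attempts):
    if attempts >= 5:
        return None  # permanently fail (and bounce) after five attempts
    return 60 * (2 ** attempts)  # 60s, 120s, 240s, ...

queue = Queue(store, relay, backoff=exponential_backoff)
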
コード例 #59
class RaidenService(Runnable):
    """ A Raiden node. """

    def __init__(
            self,
            chain: BlockChainService,
            query_start_block: BlockNumber,
            default_registry: TokenNetworkRegistry,
            default_secret_registry: SecretRegistry,
            transport,
            raiden_event_handler,
            message_handler,
            config,
            discovery=None,
    ):
        super().__init__()
        self.tokennetworkids_to_connectionmanagers = dict()
        self.targets_to_identifiers_to_statuses: StatusesDict = defaultdict(dict)

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config

        self.signer: Signer = LocalSigner(self.chain.client.privkey)
        self.address = self.signer.address
        self.discovery = discovery
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.raiden_event_handler = raiden_event_handler
        self.message_handler = message_handler

        self.stop_event = Event()
        self.stop_event.set()  # inits as stopped

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state change events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.contract_manager = ContractManager(config['contracts_path'])
        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir

            # Two raiden processes must not write to the same database, even
            # though the database itself may be consistent. If more than one
            # node writes state changes to the same WAL there are no
            # guarantees about recovery, because the WAL replay cannot be
            # deterministic.
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        self.event_poll_lock = gevent.lock.Semaphore()
        self.gas_reserve_lock = gevent.lock.Semaphore()
        self.payment_identifier_lock = gevent.lock.Semaphore()

    def start(self):
        """ Start the node synchronously. Raises directly if anything went wrong on startup """
        if not self.stop_event.ready():
            raise RuntimeError(f'{self!r} already started')
        self.stop_event.clear()

        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        self.maybe_upgrade_db()

        storage = sqlite.SerializedSQLiteStorage(
            database_path=self.database_path,
            serializer=serialize.JSONSerializer(),
        )
        storage.log_run()
        self.wal = wal.restore_to_state_change(
            transition_function=node.state_transition,
            storage=storage,
            state_change_identifier='latest',
        )

        if self.wal.state_manager.current_state is None:
            log.debug(
                'No recoverable state available, created initial state',
                node=pex(self.address),
            )
            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = self.query_start_block

            state_change = ActionInitChain(
                random.Random(),
                last_log_block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.handle_state_change(state_change)

            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states, as this is the node's first startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.EMPTY_HASH,
                payment_network,
                last_log_block_number,
            )
            self.handle_state_change(state_change)
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(self.wal.state_manager.current_state)
            log.debug(
                'Restored state from WAL',
                last_restored_block=last_log_block_number,
                node=pex(self.address),
            )

            known_networks = views.get_payment_network_identifiers(views.state_from_raiden(self))
            if known_networks and self.default_registry.address not in known_networks:
                configured_registry = pex(self.default_registry.address)
                known_registries = lpex(known_networks)
                raise RuntimeError(
                    f'Token network address mismatch.\n'
                    f'Raiden is configured to use the smart contract '
                    f'{configured_registry}, which conflicts with the current known '
                    f'smart contracts {known_registries}',
                )

        # Restore the current snapshot group
        state_change_qty = self.wal.storage.count_state_changes()
        self.snapshot_group = state_change_qty // SNAPSHOT_STATE_CHANGES_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        #   otherwise the state changes won't take effect.
        # - The alarm must complete its first run before the transport is started,
        #   to reject messages for closed/settled channels.
        self.alarm.register_callback(self._callback_new_block)
        with self.dispatch_events_lock:
            self.alarm.first_run(last_log_block_number)

        chain_state = views.state_from_raiden(self)
        self._initialize_transactions_queues(chain_state)
        self._initialize_whitelists(chain_state)

        # Send the messages in the queues before starting the transport. This
        # is necessary to avoid a race where, if the transport were started
        # before the messages are queued, actions triggered by it could cause
        # new messages to be enqueued before these older ones.
        self._initialize_messages_queues(chain_state)

        # The transport must never be started before the alarm task's
        # `first_run()` has completed, because it is that method which
        # synchronizes the node with the blockchain, including the channels'
        # state (if a channel is closed on-chain, new messages must be
        # rejected, which will not happen if the node is not synchronized)
        self.transport.start(
            raiden_service=self,
            message_handler=self.message_handler,
            prev_auth_data=chain_state.last_transport_authdata,
        )

        # First run has been called above!
        self.alarm.start()

        # exceptions on these subtasks should crash the app and bubble up
        self.alarm.link_exception(self.on_error)
        self.transport.link_exception(self.on_error)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck(chain_state)

        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet.get()  # re-raise if exception occurred

        log.debug('Raiden Service started', node=pex(self.address))
        super().start()

    def _run(self, *args, **kwargs):  # pylint: disable=method-hidden
        """ Busy-wait on long-lived subtasks/greenlets, re-raise if any error occurs """
        try:
            self.stop_event.wait()
        except gevent.GreenletExit:  # killed without exception
            self.stop_event.set()
            gevent.killall([self.alarm, self.transport])  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()
            raise

    def stop(self):
        """ Stop the node gracefully. Raise if any stop-time error occurred on any subtask """
        if self.stop_event.ready():  # not started
            return

        # Needs to come before any greenlets joining
        self.stop_event.set()

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        self.transport.stop()
        self.alarm.stop()

        self.transport.join()
        self.alarm.join()

        self.blockchain_events.uninstall_all_event_listeners()

        if self.db_lock is not None:
            self.db_lock.release()

        log.debug('Raiden Service stopped', node=pex(self.address))

    def add_pending_greenlet(self, greenlet: gevent.Greenlet):
        greenlet.link_exception(self.on_error)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def start_neighbours_healthcheck(self, chain_state: ChainState):
        for neighbour in views.all_neighbour_nodes(chain_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def get_block_number(self) -> BlockNumber:
        return views.block_number(self.wal.state_manager.current_state)

    def on_message(self, message: Message):
        self.message_handler.on_message(self, message)

    def handle_state_change(self, state_change: StateChange):
        log.debug(
            'State change',
            node=pex(self.address),
            state_change=_redact_secret(serialize.JSONSerializer.serialize(state_change)),
        )

        event_list = self.wal.log_and_dispatch(state_change)

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug(
                'Raiden event',
                node=pex(self.address),
                raiden_event=_redact_secret(serialize.JSONSerializer.serialize(event)),
            )

            try:
                self.raiden_event_handler.on_raiden_event(
                    raiden=self,
                    event=event,
                )
            except RaidenRecoverableError as e:
                log.error(str(e))
            except InvalidDBData:
                raise
            except RaidenUnrecoverableError as e:
                log_unrecoverable = (
                    self.config['environment_type'] == Environment.PRODUCTION and
                    not self.config['unrecoverable_error_should_crash']
                )
                if log_unrecoverable:
                    log.error(str(e))
                else:
                    raise

        # Take a snapshot every SNAPSHOT_STATE_CHANGES_COUNT state changes.
        # TODO: Gather more data about storage requirements and update the
        # value to specify how often a snapshot should be taken.
        new_snapshot_group = self.wal.storage.count_state_changes() // SNAPSHOT_STATE_CHANGES_COUNT
        if new_snapshot_group > self.snapshot_group:
            log.debug('Storing snapshot', snapshot_id=new_snapshot_group)
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        return event_list

    def set_node_network_state(self, node_address: Address, network_state: str):
        state_change = ActionChangeNodeNetworkState(node_address, network_state)
        self.handle_state_change(state_change)

    def start_health_check_for(self, node_address: Address):
        # This function is a no-op during initialization. It can be called
        # through the alarm task while polling for new channel events. The
        # healthcheck will be started by self.start_neighbours_healthcheck()
        if self.transport:
            self.transport.start_health_check(node_address)

    def _callback_new_block(self, latest_block: Dict):
        """Called once a new block is detected by the alarm task.

        Note:
            This should be called only once per block, otherwise there will be
            duplicated `Block` state changes in the log.

            Therefore this method should be called only once a new block is
            mined with the corresponding block data from the AlarmTask.
        """
        # User facing APIs, which have on-chain side-effects, force polled the
        # blockchain to update the node's state. This force poll is used to
        # provide a consistent view to the user, e.g. a channel open call waits
        # for the transaction to be mined and force polled the event to update
        # the node's state. This pattern introduced a race with the alarm task
        # and the task which served the user request, because the events are
        # returned only once per filter. The lock below is to protect against
        # these races (introduced by the commit
        # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d)
        with self.event_poll_lock:
            latest_block_number = latest_block['number']
            confirmation_blocks = self.config['blockchain']['confirmation_blocks']
            confirmed_block_number = latest_block_number - confirmation_blocks
            confirmed_block = self.chain.client.web3.eth.getBlock(confirmed_block_number)

            # handle testing private chains
            confirmed_block_number = max(GENESIS_BLOCK_NUMBER, confirmed_block_number)

            for event in self.blockchain_events.poll_blockchain_events(confirmed_block_number):
                # These state changes will be processed with a block_number
                # which is /larger/ than the ChainState's block_number.
                on_blockchain_event(self, event)

            # On restart the Raiden node will re-create the filters with the
            # ethereum node. These filters will have the from_block set to the
            # value of the latest Block state change. To avoid missing events
            # the Block state change is dispatched only after all of the events
            # have been processed.
            #
            # This means on some corner cases a few events may be applied
            # twice, this will happen if the node crashed and some events have
            # been processed but the Block state change has not been
            # dispatched.
            state_change = Block(
                block_number=confirmed_block_number,
                gas_limit=confirmed_block['gasLimit'],
                block_hash=bytes(confirmed_block['hash']),
            )
            self.handle_state_change(state_change)

    def _register_payment_status(
            self,
            target: TargetAddress,
            identifier: PaymentID,
            balance_proof: BalanceProofUnsignedState,
    ):
        with self.payment_identifier_lock:
            self.targets_to_identifiers_to_statuses[target][identifier] = PaymentStatus(
                payment_identifier=identifier,
                amount=balance_proof.transferred_amount,
                token_network_identifier=balance_proof.token_network_identifier,
                payment_done=AsyncResult(),
            )

    def _initialize_transactions_queues(self, chain_state: ChainState):
        pending_transactions = views.get_pending_transactions(chain_state)

        log.debug(
            'Processing pending transactions',
            num_pending_transactions=len(pending_transactions),
            node=pex(self.address),
        )

        with self.dispatch_events_lock:
            for transaction in pending_transactions:
                try:
                    self.raiden_event_handler.on_raiden_event(self, transaction)
                except RaidenRecoverableError as e:
                    log.error(str(e))
                except InvalidDBData:
                    raise
                except RaidenUnrecoverableError as e:
                    log_unrecoverable = (
                        self.config['environment_type'] == Environment.PRODUCTION and
                        not self.config['unrecoverable_error_should_crash']
                    )
                    if log_unrecoverable:
                        log.error(str(e))
                    else:
                        raise

    def _initialize_messages_queues(self, chain_state: ChainState):
        """ Push the queues to the transport and populate
        targets_to_identifiers_to_statuses.
        """
        events_queues = views.get_all_messagequeues(chain_state)

        for queue_identifier, event_queue in events_queues.items():
            self.start_health_check_for(queue_identifier.recipient)

            for event in event_queue:
                is_initiator = (
                    type(event) == SendLockedTransfer and
                    event.transfer.initiator == self.address
                )

                if is_initiator:
                    self._register_payment_status(
                        target=event.transfer.target,
                        identifier=event.transfer.payment_identifier,
                        balance_proof=event.transfer.balance_proof,
                    )

                message = message_from_sendevent(event, self.address)
                self.sign(message)
                self.transport.send_async(queue_identifier, message)

    def _initialize_whitelists(self, chain_state: ChainState):
        """ Whitelist neighbors and mediated transfer targets on transport """

        for neighbour in views.all_neighbour_nodes(chain_state):
            if neighbour == ConnectionManager.BOOTSTRAP_ADDR:
                continue
            self.transport.whitelist(neighbour)

        events_queues = views.get_all_messagequeues(chain_state)

        for event_queue in events_queues.values():
            for event in event_queue:
                is_initiator = (
                    type(event) == SendLockedTransfer and
                    event.transfer.initiator == self.address
                )
                if is_initiator:
                    self.transport.whitelist(address=event.transfer.target)

    def sign(self, message: Message):
        """ Sign message inplace. """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.signer)

    def install_all_blockchain_filters(
            self,
            token_network_registry_proxy: TokenNetworkRegistry,
            secret_registry_proxy: SecretRegistry,
            from_block: BlockNumber,
    ):
        with self.event_poll_lock:
            node_state = views.state_from_raiden(self)
            token_networks = views.get_token_network_identifiers(
                node_state,
                token_network_registry_proxy.address,
            )

            self.blockchain_events.add_token_network_registry_listener(
                token_network_registry_proxy=token_network_registry_proxy,
                contract_manager=self.contract_manager,
                from_block=from_block,
            )
            self.blockchain_events.add_secret_registry_listener(
                secret_registry_proxy=secret_registry_proxy,
                contract_manager=self.contract_manager,
                from_block=from_block,
            )

            for token_network in token_networks:
                token_network_proxy = self.chain.token_network(
                    TokenNetworkAddress(token_network),
                )
                self.blockchain_events.add_token_network_listener(
                    token_network_proxy=token_network_proxy,
                    contract_manager=self.contract_manager,
                    from_block=from_block,
                )

    def connection_manager_for_token_network(
            self,
            token_network_identifier: TokenNetworkID,
    ) -> ConnectionManager:
        if not is_binary_address(token_network_identifier):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_identifiers(
            views.state_from_raiden(self),
            self.default_registry.address,
        )

        if token_network_identifier not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokennetworkids_to_connectionmanagers.get(token_network_identifier)

        if manager is None:
            manager = ConnectionManager(self, token_network_identifier)
            self.tokennetworkids_to_connectionmanagers[token_network_identifier] = manager

        return manager

    def mediated_transfer_async(
            self,
            token_network_identifier: TokenNetworkID,
            amount: TokenAmount,
            target: TargetAddress,
            identifier: PaymentID,
    ) -> PaymentStatus:
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer; the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.
        """

        secret = random_secret()
        payment_status = self.start_mediated_transfer_with_secret(
            token_network_identifier,
            amount,
            target,
            identifier,
            secret,
        )

        return payment_status

    def start_mediated_transfer_with_secret(
            self,
            token_network_identifier: TokenNetworkID,
            amount: TokenAmount,
            target: TargetAddress,
            identifier: PaymentID,
            secret: Secret,
    ) -> PaymentStatus:

        secret_hash = sha3(secret)
        # LEFTODO: Supply a proper block id
        secret_registered = self.default_secret_registry.check_registered(
            secrethash=secret_hash,
            block_identifier='latest',
        )
        if secret_registered:
            raise RaidenUnrecoverableError(
                f'Attempted to initiate a locked transfer with secrethash {pex(secret_hash)}.'
                f' That secret is already registered onchain.',
            )

        self.start_health_check_for(Address(target))

        if identifier is None:
            identifier = create_default_identifier()

        with self.payment_identifier_lock:
            payment_status = self.targets_to_identifiers_to_statuses[target].get(identifier)
            if payment_status:
                payment_status_matches = payment_status.matches(
                    token_network_identifier,
                    amount,
                )
                if not payment_status_matches:
                    raise PaymentConflict(
                        'Another payment with the same id is in flight',
                    )

                return payment_status

            payment_status = PaymentStatus(
                payment_identifier=identifier,
                amount=amount,
                token_network_identifier=token_network_identifier,
                payment_done=AsyncResult(),
                secret=secret,
                secret_hash=secret_hash,
            )
            self.targets_to_identifiers_to_statuses[target][identifier] = payment_status

        init_initiator_statechange = initiator_init(
            raiden=self,
            transfer_identifier=identifier,
            transfer_amount=amount,
            transfer_secret=secret,
            token_network_identifier=token_network_identifier,
            target_address=target,
        )

        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return payment_status

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        self.start_health_check_for(transfer.initiator)
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)

    def maybe_upgrade_db(self):
        manager = UpgradeManager(db_filename=self.database_path)
        manager.run()
コード例 #60
 def __init__(self, url=None, ie_info=None, *args, **kwargs):
     super(YoutubeDLInput, self).__init__(None, *args, **kwargs)
     self._url = url
     self._ie_info = ie_info
     self._info = None
     self._info_lock = Semaphore()
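
A hypothetical companion method for the class above, sketching the usual lazy-initialization pattern behind such a lock; `self._extract_info` is an assumed helper, not shown in the snippet.

def _get_info(self):
    with self._info_lock:
        # Only the first greenlet performs the extraction; later callers
        # block on the semaphore and then reuse the cached value.
        if self._info is None:
            self._info = self._extract_info(self._url)
        return self._info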