Example #1
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100, *args, **kwargs):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')

        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url, *args, **kwargs)

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(
            super(GeventedHTTPTransport, self).send, data, headers
        ).link(lambda x: self._done(x, success_cb, failure_cb))

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
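
The pattern above (a counting semaphore capping in-flight greenlets, released from a link callback) also works on its own. A minimal sketch, assuming a placeholder do_work function:

import gevent
from gevent.lock import Semaphore

limit = Semaphore(10)  # at most 10 greenlets in flight

def do_work(i):
    gevent.sleep(0.01)
    return i

def spawn_bounded(i):
    limit.acquire()
    g = gevent.spawn(do_work, i)
    g.link(lambda _g: limit.release())  # release exactly once, success or failure
    return g

gevent.joinall([spawn_bounded(i) for i in range(100)])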
Example #2
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):

    scheme = ["gevent+http", "gevent+https"]

    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        if not has_gevent:
            raise ImportError("GeventedHTTPTransport requires gevent.")
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split("+", 1)[-1]

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(super(GeventedHTTPTransport, self).send, data, headers).link(
            lambda x: self._done(x, success_cb, failure_cb)
        )

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            # .exception holds the error; .value would be None on failure
            failure_cb(greenlet.exception)
Example #3
class Pool(Group):
    def __init__(self, size=None, greenlet_class=None):
        if size is not None and size < 0:
            raise ValueError("size must not be negative: %r" % (size,))
        Group.__init__(self)
        self.size = size
        if greenlet_class is not None:
            self.greenlet_class = greenlet_class
        if size is None:
            self._semaphore = DummySemaphore()
        else:
            self._semaphore = Semaphore(size)

    def wait_available(self):
        self._semaphore.wait()

    def full(self):
        return self.free_count() <= 0

    def free_count(self):
        if self.size is None:
            return 1
        return max(0, self.size - len(self))

    def add(self, greenlet):
        self._semaphore.acquire()
        try:
            Group.add(self, greenlet)
        except:
            self._semaphore.release()
            raise

    def _discard(self, greenlet):
        Group._discard(self, greenlet)
        self._semaphore.release()
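
This is essentially the Pool shipped in gevent.pool: add() acquires the semaphore, so spawn() blocks once size greenlets are running. A short usage sketch:

import gevent
from gevent.pool import Pool

pool = Pool(size=3)

def task(n):
    gevent.sleep(0.1)
    return n * n

for n in range(10):
    pool.spawn(task, n)  # blocks while three tasks are already running
pool.join()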
Example #4
class GeventedHTTPTransport(HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]

    def send(self, data, headers):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(
            super(GeventedHTTPTransport, self).send, data,
            headers).link(self._done, self)

    def _done(self, *args):
        self._lock.release()
Example #5
class _RPCConnection(object):
    def __init__(self, sock, use_greenlets):
        self._sock = sock
        self.use_greenlets = use_greenlets
        if use_greenlets:
            from gevent.lock import Semaphore

            self._send_lock = Semaphore()

    def recv(self, buf_size):
        return self._sock.recv(buf_size)

    def send(self, msg):
        if self.use_greenlets:
            self._send_lock.acquire()
            try:
                self._sock.sendall(msg)
            finally:
                self._send_lock.release()
        else:
            try:
                self._sock.sendall(msg)
            finally:
                pass

    def __del__(self):
        try:
            self._sock.close()
        except:
            pass
Example #6
class GeventedHTTPTransport(HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]

    def send(self, data, headers):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(super(GeventedHTTPTransport, self).send, data, headers).link(self._done, self)

    def _done(self, *args):
        self._lock.release()
Example #7
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, maximum_outstanding_requests=100, *args, **kwargs):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')

        self._lock = Semaphore(maximum_outstanding_requests)

        super().__init__(*args, **kwargs)

    def async_send(self, url, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.
        """
        # this can be optimized by making a custom self.send that does not
        # read the response since we don't use it.
        self._lock.acquire()
        return gevent.spawn(
            super().send, url, data, headers
        ).link(lambda x: self._done(x, success_cb, failure_cb))

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
Example #8
class Pool(Group):
    def __init__(self, size=None, greenlet_class=None):
        if size is not None and size < 0:
            raise ValueError('size must not be negative: %r' % (size, ))
        Group.__init__(self)
        self.size = size
        if greenlet_class is not None:
            self.greenlet_class = greenlet_class
        if size is None:
            self._semaphore = DummySemaphore()
        else:
            self._semaphore = Semaphore(size)

    def wait_available(self):
        self._semaphore.wait()

    def full(self):
        return self.free_count() <= 0

    def free_count(self):
        if self.size is None:
            return 1
        return max(0, self.size - len(self))

    def add(self, greenlet):
        self._semaphore.acquire()
        try:
            Group.add(self, greenlet)
        except:
            self._semaphore.release()
            raise

    def _discard(self, greenlet):
        Group._discard(self, greenlet)
        self._semaphore.release()
Example #9
 def test_release_twice(self):
     s = Semaphore()
     result = []
     s.rawlink(lambda s: result.append('a'))
     s.release()
     s.rawlink(lambda s: result.append('b'))
     s.release()
     gevent.sleep(0.001)
     self.assertEqual(result, ['a', 'b'])
Example #10
 def test_release_twice(self):
     s = Semaphore()
     result = []
     s.rawlink(lambda s: result.append('a'))
     s.release()
     s.rawlink(lambda s: result.append('b'))
     s.release()
     gevent.sleep(0.001)
     self.assertEqual(result, ['a', 'b'])
Example #11
 def test_release_twice(self):
     s = Semaphore()
     result = []
     s.rawlink(lambda s: result.append('a'))
     s.release()
     s.rawlink(lambda s: result.append('b'))
     s.release()
     gevent.sleep(0.001)
     # The order, though, is not guaranteed.
     self.assertEqual(sorted(result), ['a', 'b'])
Example #12
 def test_release_twice(self):
     s = Semaphore()
     result = []
     s.rawlink(lambda s: result.append('a'))
     s.release()
     s.rawlink(lambda s: result.append('b'))
     s.release()
     gevent.sleep(0.001)
     # The order, though, is not guaranteed.
     self.assertEqual(sorted(result), ['a', 'b'])
Example #13
class ModLogPump(object):
    def __init__(self, channel, max_actions, action_time):
        self.channel = channel
        self.action_time = action_time

        self._have = Event()
        self._buffer = []
        self._lock = Semaphore(max_actions)
        self._emitter = gevent.spawn(self._emit_loop)

    def _get_next_message(self):
        data = ''

        while self._buffer:
            payload = self._buffer.pop(0)
            if len(data) + len(payload) > 2000:
                if data:
                    # put the payload back for the next batch instead of dropping it
                    self._buffer.insert(0, payload)
                break
            data += '\n'
            data += payload

        return data

    def _emit_loop(self):
        while True:
            self._have.wait()

            try:
                self._emit()
            except APIException as e:
                # If sending messages is disabled, back off (we'll drop events
                # but that's ok)
                if e.code == 40004:
                    gevent.sleep(5)

            if not len(self._buffer):
                self._have.clear()

    def _emit(self):
        self._lock.acquire()
        msg = self._get_next_message()
        if not msg:
            self._lock.release()
            return
        self.channel.send_message(msg)
        gevent.spawn(self._emit_unlock)

    def _emit_unlock(self):
        gevent.sleep(self.action_time)
        self._lock.release()

    def add_message(self, payload):
        self._buffer.append(payload)
        self._have.set()
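
The _emit/_emit_unlock pair above implements a rolling rate limit: a semaphore of max_actions permits, each returned only after action_time seconds. A minimal standalone sketch of the same idiom (send_fn is a hypothetical stand-in for channel.send_message):

import gevent
from gevent.lock import Semaphore

class RateLimiter(object):
    def __init__(self, max_actions, period):
        self._sem = Semaphore(max_actions)
        self._period = period

    def run(self, send_fn, *args):
        self._sem.acquire()  # blocks once max_actions ran within the window
        try:
            return send_fn(*args)
        finally:
            # return the permit only after `period`, keeping the window rolling
            gevent.spawn_later(self._period, self._sem.release)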
Example #14
class FifoQueue:
    def __init__(self):
        self._queue = deque()
        self._semaphore = Semaphore(0)

    def __len__(self):
        return len(self._queue)

    def push(self, request):
        self._queue.append(request)
        self._semaphore.release()

    def pop(self):
        self._semaphore.acquire()
        return self._queue.popleft()
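
Because the semaphore starts at zero, pop() blocks until a matching push() has released a permit. A usage sketch:

import gevent

queue = FifoQueue()

def consumer():
    print(queue.pop())  # blocks until the producer pushes

g = gevent.spawn(consumer)
gevent.sleep(0)  # let the consumer block first
queue.push("hello")
g.join()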
Example #15
class GeventConnectionPool(psycopg2.pool.AbstractConnectionPool):
    def __init__(self, minconn, maxconn, *args, **kwargs):
        self.semaphore = Semaphore(maxconn)
        psycopg2.pool.AbstractConnectionPool.__init__(self, minconn, maxconn, *args, **kwargs)

    def getconn(self, *args, **kwargs):
        self.semaphore.acquire()
        return self._getconn(*args, **kwargs)

    def putconn(self, *args, **kwargs):
        self._putconn(*args, **kwargs)
        self.semaphore.release()

    # Not sure what to do about this one...
    closeall = psycopg2.pool.AbstractConnectionPool._closeall
Example #16
def test_renamed_label_refresh(db, default_account, thread, message, imapuid,
                               folder, mock_imapclient, monkeypatch):
    # Check that imapuids see their labels refreshed after running
    # the LabelRenameHandler.
    msg_uid = imapuid.msg_uid
    uid_dict = {msg_uid: GmailFlags((), ('stale label', ), ('23', ))}

    update_metadata(default_account.id, folder.id, folder.canonical_name,
                    uid_dict, db.session)

    new_flags = {
        msg_uid: {
            'FLAGS': ('\\Seen', ),
            'X-GM-LABELS': ('new label', ),
            'MODSEQ': ('23', )
        }
    }
    mock_imapclient._data['[Gmail]/All mail'] = new_flags

    mock_imapclient.add_folder_data(folder.name, new_flags)

    monkeypatch.setattr(MockIMAPClient, 'search', lambda x, y: [msg_uid])

    semaphore = Semaphore(value=1)

    rename_handler = LabelRenameHandler(default_account.id,
                                        default_account.namespace.id,
                                        'new label', semaphore)

    # Acquire the semaphore to check that LabelRenameHandlers block if
    # the semaphore is in-use.
    semaphore.acquire()
    rename_handler.start()

    # Wait 10 secs and check that the data hasn't changed.
    gevent.sleep(10)

    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'stale label'
    semaphore.release()
    rename_handler.join()

    db.session.refresh(imapuid)
    # Now check that the label got updated.
    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'new label'
Example #17
class JobSharedLock(object):
    """
    Shared lock for jobs.
    Each job method can specify a lock which will be shared
    among all calls for that job and only one job can run at a time
    for this lock.
    """

    def __init__(self, queue, name):
        self.queue = queue
        self.name = name
        self.jobs = []
        self.semaphore = Semaphore()

    def add_job(self, job):
        self.jobs.append(job)

    def get_jobs(self):
        return self.jobs

    def remove_job(self, job):
        self.jobs.remove(job)

    def locked(self):
        return self.semaphore.locked()

    def acquire(self):
        return self.semaphore.acquire()

    def release(self):
        return self.semaphore.release()
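
A hedged usage sketch; the queue argument is opaque here (the class only stores it), and run_job is a hypothetical caller:

lock = JobSharedLock(queue=None, name="update-config")

def run_job(job):
    lock.add_job(job)
    lock.acquire()
    try:
        job()  # only one job for this lock runs at a time
    finally:
        lock.release()
        lock.remove_job(job)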
Example #18
class GeventConnectionPool(psycopg2.pool.AbstractConnectionPool):
    def __init__(self, minconn, maxconn, *args, **kwargs):
        self.semaphore = Semaphore(maxconn)
        psycopg2.pool.AbstractConnectionPool.__init__(self, minconn, maxconn,
                                                      *args, **kwargs)

    def getconn(self, *args, **kwargs):
        self.semaphore.acquire()
        return self._getconn(*args, **kwargs)

    def putconn(self, *args, **kwargs):
        self._putconn(*args, **kwargs)
        self.semaphore.release()

    # Not sure what to do about this one...
    closeall = psycopg2.pool.AbstractConnectionPool._closeall
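
A usage sketch, assuming the psycogreen package is installed to make psycopg2 cooperative (otherwise each query blocks the whole event loop); the DSN is a placeholder:

from psycogreen.gevent import patch_psycopg

patch_psycopg()
pool = GeventConnectionPool(1, 10, "dbname=example user=example")

conn = pool.getconn()  # blocks once 10 connections are checked out
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
finally:
    pool.putconn(conn)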
Example #19
class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._semaphore = Semaphore(0)

    def __len__(self):
        return len(self._queue)

    def push(self, request):
        heappush(self._queue, _PriorityQueueItem(request))
        self._semaphore.release()

    def pop(self):
        self._semaphore.acquire()
        item = heappop(self._queue)
        return item.request
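
_PriorityQueueItem is referenced but not shown; a plausible minimal sketch, assuming each request carries a numeric priority attribute:

class _PriorityQueueItem(object):
    def __init__(self, request):
        self.request = request

    def __lt__(self, other):
        # heapq pops the smallest item first
        return self.request.priority < other.request.priority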
Example #20
class JobSharedLock(object):
    """
    Shared lock for jobs.
    Each job method can specify a lock which will be shared
    among all calls for that job and only one job can run at a time
    for this lock.
    """

    def __init__(self, queue, name):
        self.queue = queue
        self.name = name
        self.jobs = []
        self.semaphore = Semaphore()

    def add_job(self, job):
        self.jobs.append(job)

    def get_jobs(self):
        return self.jobs

    def remove_job(self, job):
        self.jobs.remove(job)

    def locked(self):
        return self.semaphore.locked()

    def acquire(self):
        return self.semaphore.acquire()

    def release(self):
        return self.semaphore.release()
Example #21
class Puppet(object):
    def __init__(self):
        self.results = {}
        self.status = ("Idle", None)
        self.worker = None
        self.operator = None
        self.lock = Semaphore()

    def receive_info(self):
        while True:
            (info_id, info) = self.task_queue.get()
            if info_id > 0:
                self.results[info_id].set(info)
            elif info_id == -1:
                self.status = info

    def submit_task(self, tid, task_info):
        self.results[tid] = AsyncResult()
        self.lock.acquire()
        self.task_queue.put((tid, task_info))
        self.lock.release()

    def fetch_result(self, tid):
        res = self.results[tid].get()
        del self.results[tid]
        return res

    def hire_worker(self):
        tq_worker, self.task_queue = gipc.pipe(duplex=True)
        self.worker = gipc.start_process(target=process_worker,
                                         args=(tq_worker, ))
        self.operator = gevent.spawn(self.receive_info)

    def fire_worker(self):
        self.worker.terminate()
        self.operator.kill()

    def get_attr(self, attr):
        if attr == "cpu":
            return psutil.cpu_percent(interval=None)
        elif attr == "memory":
            return psutil.virtual_memory().percent
        else:
            return getattr(self, attr, "No such attribute")

    def current_tasks(self):
        return self.results.keys()
Example #22
def test_renamed_label_refresh(db, default_account, thread, message, imapuid,
                               folder, mock_imapclient, monkeypatch):
    # Check that imapuids see their labels refreshed after running
    # the LabelRenameHandler.
    msg_uid = imapuid.msg_uid
    uid_dict = {msg_uid: GmailFlags((), ("stale label", ), ("23", ))}

    update_metadata(default_account.id, folder.id, folder.canonical_name,
                    uid_dict, db.session)

    new_flags = {
        msg_uid: {
            b"FLAGS": (b"\\Seen", ),
            b"X-GM-LABELS": (b"new label", ),
            b"MODSEQ": (23, ),
        }
    }
    mock_imapclient._data["[Gmail]/All mail"] = new_flags

    mock_imapclient.add_folder_data(folder.name, new_flags)

    monkeypatch.setattr(MockIMAPClient, "search", lambda x, y: [msg_uid])

    semaphore = Semaphore(value=1)

    rename_handler = LabelRenameHandler(default_account.id,
                                        default_account.namespace.id,
                                        "new label", semaphore)

    # Acquire the semaphore to check that LabelRenameHandlers block if
    # the semaphore is in-use.
    semaphore.acquire()
    rename_handler.start()

    gevent.sleep(0)  # yield to the handler

    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == "stale label"
    semaphore.release()
    rename_handler.join()

    db.session.refresh(imapuid)
    # Now check that the label got updated.
    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == "new label"
Example #23
class Puppet(object):
    def __init__(self):
        self.results = {}
        self.status = ("Idle", None)
        self.worker = None
        self.operator = None
        self.lock = Semaphore()

    def receive_info(self):
        while True:
            (info_id, info) = self.task_queue.get()
            if info_id > 0:
                self.results[info_id].set(info)
            elif info_id == -1:
                self.status = info

    def submit_task(self, tid, task_info):
        self.results[tid] = AsyncResult()
        self.lock.acquire()
        self.task_queue.put((tid, task_info))
        self.lock.release()

    def fetch_result(self, tid):
        res = self.results[tid].get()
        del self.results[tid]
        return res

    def hire_worker(self):
        tq_worker, self.task_queue = gipc.pipe(duplex=True)
        self.worker = gipc.start_process(target=process_worker, args=(tq_worker,))
        self.operator = gevent.spawn(self.receive_info)

    def fire_worker(self):
        self.worker.terminate()
        self.operator.kill()

    def get_attr(self, attr):
        if attr == "cpu":
            return psutil.cpu_percent(interval=None)
        elif attr == "memory":
            return psutil.virtual_memory().percent
        else:
            return getattr(self, attr, "No such attribute")

    def current_tasks(self):
        return self.results.keys()
Example #24
def test_renamed_label_refresh(db, default_account, thread, message,
                               imapuid, folder, mock_imapclient, monkeypatch):
    # Check that imapuids see their labels refreshed after running
    # the LabelRenameHandler.
    msg_uid = imapuid.msg_uid
    uid_dict = {msg_uid: GmailFlags((), ('stale label',), ('23',))}

    update_metadata(default_account.id, folder.id, folder.canonical_name,
                    uid_dict, db.session)

    new_flags = {msg_uid: {'FLAGS': ('\\Seen',), 'X-GM-LABELS': ('new label',),
                           'MODSEQ': ('23',)}}
    mock_imapclient._data['[Gmail]/All mail'] = new_flags

    mock_imapclient.add_folder_data(folder.name, new_flags)

    monkeypatch.setattr(MockIMAPClient, 'search',
                        lambda x, y: [msg_uid])

    semaphore = Semaphore(value=1)

    rename_handler = LabelRenameHandler(default_account.id,
                                        default_account.namespace.id,
                                        'new label', semaphore)

    # Acquire the semaphore to check that LabelRenameHandlers block if
    # the semaphore is in-use.
    semaphore.acquire()
    rename_handler.start()

    # Wait 10 secs and check that the data hasn't changed.
    gevent.sleep(10)

    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'stale label'
    semaphore.release()
    rename_handler.join()

    db.session.refresh(imapuid)
    # Now check that the label got updated.
    labels = list(imapuid.labels)
    assert len(labels) == 1
    assert labels[0].name == 'new label'
Example #25
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):
    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests = 100, *args, **kwargs):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)
        super(GeventedHTTPTransport, self).__init__(parsed_url, *args, **kwargs)

    def async_send(self, data, headers, success_cb, failure_cb):
        self._lock.acquire()
        return gevent.spawn(super(GeventedHTTPTransport, self).send, data, headers).link(lambda x: self._done(x, success_cb, failure_cb))

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
Example #26
class DatagramServer(BaseServer):
    """A UDP server"""

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        # The raw (non-gevent) socket, if possible
        self._socket = None
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        if not hasattr(self, 'socket'):
            # FIXME: clean up the socket lifetime
            # pylint:disable=attribute-defined-outside-init
            self.socket = self.get_listener(self.address, self.family)
            self.address = self.socket.getsockname()
        self._socket = self.socket
        try:
            self._socket = self._socket._sock
        except AttributeError:
            pass

    @classmethod
    def get_listener(cls, address, family=None):
        return _udp_socket(address, reuse_addr=cls.reuse_addr, family=family)

    def do_read(self):
        try:
            data, address = self._socket.recvfrom(8192)
        except _socket.error as err:
            if err.args[0] == EWOULDBLOCK:
                return
            raise
        return data, address

    def sendto(self, *args):
        self._writelock.acquire()
        try:
            self.socket.sendto(*args)
        finally:
            self._writelock.release()
Example #27
class DatagramServer(BaseServer):
    """A UDP server"""

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        # The raw (non-gevent) socket, if possible
        self._socket = None
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        if not hasattr(self, 'socket'):
            # FIXME: clean up the socket lifetime
            # pylint:disable=attribute-defined-outside-init
            self.socket = self.get_listener(self.address, self.family)
            self.address = self.socket.getsockname()
        self._socket = self.socket
        try:
            self._socket = self._socket._sock
        except AttributeError:
            pass

    @classmethod
    def get_listener(cls, address, family=None):
        return _udp_socket(address, reuse_addr=cls.reuse_addr, family=family)

    def do_read(self):
        try:
            data, address = self._socket.recvfrom(8192)
        except SocketError as err:
            if err.args[0] == EWOULDBLOCK:
                return
            raise
        return data, address

    def sendto(self, *args):
        self._writelock.acquire()
        try:
            self.socket.sendto(*args)
        finally:
            self._writelock.release()
Example #28
class ProxyGuard(object):
    """
    Helper class that can be used to protect a proxy object that represents a remote resource.
    It is possible to specify the maximum number of parallel calls of any method.
    This is helpful in the case that the resource limits the number of parallel connections.
    """
    def __init__(self,
                 proxy: Any,
                 max_parallel_access=1,
                 attr_name: Optional[str] = None,
                 semaphore: Optional[Semaphore] = None) -> None:
        self._attr_name = attr_name
        self._proxy = proxy
        self._max_parallel_access = max_parallel_access

        if semaphore is not None:
            self._semaphore = semaphore
        else:
            self._semaphore = Semaphore(value=max_parallel_access)

    def __getattr__(self, name):
        return ProxyGuard(self._proxy,
                          max_parallel_access=self._max_parallel_access,
                          attr_name=name,
                          semaphore=self._semaphore)

    def __call__(self, *args, **kwargs) -> Any:
        self._semaphore.acquire(blocking=True)

        ex = None
        result = None

        try:
            result = getattr(self._proxy, self._attr_name)(*args, **kwargs)
        except BaseException as base_exception:
            ex = base_exception
        finally:
            self._semaphore.release()

        if ex is not None:
            raise ex

        return result
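
A usage sketch with a hypothetical XML-RPC proxy; any object whose method calls should be capped works the same way, since every attribute lookup returns a guard sharing the same semaphore:

import xmlrpc.client

raw = xmlrpc.client.ServerProxy("http://localhost:8000")  # placeholder URL
guarded = ProxyGuard(raw, max_parallel_access=4)

# at most four concurrent calls reach the server; the rest block on the semaphore
result = guarded.some_remote_method(42)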
Example #29
class DatagramServer(BaseServer):
    """A UDP server"""

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        if not hasattr(self, 'socket'):
            self.socket = self.get_listener(self.address, self.family)
            self.address = self.socket.getsockname()
        self._socket = self.socket
        try:
            if sys.version_info[0] == 2:
                self._socket = self._socket._sock
            else:
                self._socket = super(socket, self._socket)
        except AttributeError:
            pass

    @classmethod
    def get_listener(cls, address, family=None):
        return _udp_socket(address, reuse_addr=cls.reuse_addr, family=family)

    def do_read(self):
        try:
            data, address = self._socket.recvfrom(8192)
        except _socket.error as err:
            if err.args[0] == EWOULDBLOCK:
                return
            raise
        return data, address

    def sendto(self, *args):
        self._writelock.acquire()
        try:
            self.socket.sendto(*args)
        finally:
            self._writelock.release()
Example #30
class Hub:
    def __init__(self, num):
        self.need = 0
        self.block_queue = Queue(maxsize=1)
        self.total_player = []
        for n in range(num):
            self.total_player.append(n)
        self.semaphore = Semaphore()

    def _get_player(self, num):
        if len(self.total_player) < num:
            # block here
            self.need = num
            print "_get_player set need=%s" % self.need
            self.block_queue.get(block=True)

        assert len(self.total_player) >= num, (
            "_get_player error total_player=%s num=%s" %
            (len(self.total_player), num))
        size = len(self.total_player)
        player_list = self.total_player[size - num:]
        self.total_player = self.total_player[:size - num]

        return player_list

    def acquire_player(self, num):
        # lock
        self.semaphore.acquire()
        print "acquire_player num=%s" % num
        player_list = self._get_player(num)
        # unlock
        self.semaphore.release()
        return player_list

    def release_player(self, player_list):
        self.total_player.extend(player_list)
        print "release_player need=%s len(self.total_player)=%s" % (
            self.need, len(self.total_player))
        if self.need == 0 or self.need > len(self.total_player):
            return
        self.need = 0
        self.block_queue.put(1, block=True)
Example #31
class TaskRunner(object):
    def __init__(self, name, task):
        self.name = name
        self.task = task
        self.lock = Semaphore(task.max_concurrent)

    def process(self, job):
        log.info('[%s] Running job %s...', job['id'], self.name)
        start = time.time()

        try:
            self.task(*job['args'], **job['kwargs'])
            if self.task.buffer_time:
                time.sleep(self.task.buffer_time)
        except:
            log.exception('[%s] Failed in %ss', job['id'], time.time() - start)

        log.info('[%s] Completed in %ss', job['id'], time.time() - start)

    def run(self, job):
        lock = None
        if self.task.global_lock:
            lock = rdb.lock('{}:{}'.format(
                self.task.name,
                self.task.global_lock(
                    *job['args'],
                    **job['kwargs']
                )
            ))
            lock.acquire()

        if self.task.max_concurrent:
            self.lock.acquire()

        self.process(job)

        if lock:
            lock.release()

        if self.task.max_concurrent:
            self.lock.release()
Example #32
class DatagramServer(BaseServer):
    """A UDP server"""

    reuse_addr = DEFAULT_REUSE_ADDR

    def __init__(self, *args, **kwargs):
        BaseServer.__init__(self, *args, **kwargs)
        from gevent.lock import Semaphore
        self._writelock = Semaphore()

    def init_socket(self):
        if not hasattr(self, 'socket'):
            self.socket = self.get_listener(self.address, self.family)
            self.address = self.socket.getsockname()
        self._socket = self.socket
        try:
            self._socket = self._socket._sock
        except AttributeError:
            pass

    @classmethod
    def get_listener(cls, address, family=None):
        return _udp_socket(address, reuse_addr=cls.reuse_addr, family=family)

    def do_read(self):
        try:
            data, address = self._socket.recvfrom(8192)
        except _socket.error:
            err = sys.exc_info()[1]
            if err.args[0] == EWOULDBLOCK:
                return
            raise
        return data, address

    def sendto(self, *args):
        self._writelock.acquire()
        try:
            self.socket.sendto(*args)
        finally:
            self._writelock.release()
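
The sendto() wrapper serializes concurrent writers through the semaphore. The canonical usage is a UDP echo server, where handle() receives (data, address) pairs from do_read():

from gevent.server import DatagramServer

class EchoServer(DatagramServer):
    def handle(self, data, address):
        self.sendto(data, address)

if __name__ == '__main__':
    EchoServer(':9000').serve_forever()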
Example #33
class LimitEngine(EngineBase):
    def __init__(self, provider, n, **kwargs):
        super(LimitEngine, self).__init__(provider, **kwargs)
        self.conf = self.provider.configurations()[0]
        self.n = n
        self.services = []
        self.num_unscheduled = -1
        self.le_lock = Semaphore()

    def before_eval(self):
        self.num_unscheduled = len(self.dag)

    def after_eval(self):
        for s in self.services:
            self.provider.stop_service(s)
        self.services = []

    def which_service(self, task):
        self.num_unscheduled -= 1
        for s in self.services:
            if len(s.tasks) == 0:
                return s
        if len(self.services) < self.n:
            s = self.provider.start_service(len(self.services) + 1, self.conf)
            self.services.append(s)
        else:
            s = min(self.services, key=lambda x: len(x.tasks))
        return s

    def after_task(self, task, service):
        self.le_lock.acquire()
        if self.num_unscheduled == 0 and len(service.tasks) == 0:
            self.provider.stop_service(service)
            for s in self.services:
                if len(s.tasks) == 0:
                    self.provider.stop_service(s)
        self.le_lock.release()

    def current_services(self):
        return self.services
Example #34
class BlockingDeque(deque):

    def __init__(self, *args, **kwargs):
        super(BlockingDeque, self).__init__(*args, **kwargs)
        self.sema = Semaphore(len(self))

    def append(self, *args, **kwargs):
        ret = super(BlockingDeque, self).append(*args, **kwargs)
        self.sema.release()
        return ret

    def appendleft(self, *args, **kwargs):
        ret = super(BlockingDeque, self).appendleft(*args, **kwargs)
        self.sema.release()
        return ret

    def clear(self, *args, **kwargs):
        ret = super(BlockingDeque, self).clear(*args, **kwargs)
        while not self.sema.locked():
            self.sema.acquire(blocking=False)
        return ret

    def extend(self, *args, **kwargs):
        pre_n = len(self)
        ret = super(BlockingDeque, self).extend(*args, **kwargs)
        post_n = len(self)
        for i in range(pre_n, post_n):
            self.sema.release()
        return ret

    def extendleft(self, *args, **kwargs):
        pre_n = len(self)
        ret = super(BlockingDeque, self).extendleft(*args, **kwargs)
        post_n = len(self)
        for i in range(pre_n, post_n):
            self.sema.release()
        return ret

    def pop(self, *args, **kwargs):
        self.sema.acquire()
        return super(BlockingDeque, self).pop(*args, **kwargs)

    def popleft(self, *args, **kwargs):
        self.sema.acquire()
        return super(BlockingDeque, self).popleft(*args, **kwargs)

    def remove(self, *args, **kwargs):
        ret = super(BlockingDeque, self).remove(*args, **kwargs)
        self.sema.acquire()
        return ret
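
popleft() blocks until another greenlet appends, because the semaphore counts the items currently in the deque. A usage sketch:

import gevent

dq = BlockingDeque()

def consumer():
    print(dq.popleft())  # blocks until append() releases the semaphore

g = gevent.spawn(consumer)
gevent.sleep(0)
dq.append("item")
g.join()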
Example #35
class GeventedHTTPTransport(AsyncTransport, HTTPTransport):

    scheme = ['gevent+http', 'gevent+https']

    def __init__(self, parsed_url, maximum_outstanding_requests=100,
                 oneway=False):
        if not has_gevent:
            raise ImportError('GeventedHTTPTransport requires gevent.')
        self._lock = Semaphore(maximum_outstanding_requests)
        self._oneway = oneway == 'true'

        super(GeventedHTTPTransport, self).__init__(parsed_url)

        # remove the gevent+ from the protocol, as it is not a real protocol
        self._url = self._url.split('+', 1)[-1]

    def async_send(self, data, headers, success_cb, failure_cb):
        """
        Spawn an async request to a remote webserver.
        """
        if not self._oneway:
            self._lock.acquire()
            return gevent.spawn(
                super(GeventedHTTPTransport, self).send, data, headers
            ).link(lambda x: self._done(x, success_cb, failure_cb))
        else:
            req = urllib2.Request(self._url, headers=headers)
            return urlopen(url=req,
                           data=data,
                           timeout=self.timeout,
                           verify_ssl=self.verify_ssl,
                           ca_certs=self.ca_certs)

    def _done(self, greenlet, success_cb, failure_cb, *args):
        self._lock.release()
        if greenlet.successful():
            success_cb()
        else:
            failure_cb(greenlet.exception)
Example #36
class BlockingDeque(deque):
    def __init__(self, *args, **kwargs):
        super(BlockingDeque, self).__init__(*args, **kwargs)
        self.sema = Semaphore(len(self))

    def append(self, *args, **kwargs):
        ret = super(BlockingDeque, self).append(*args, **kwargs)
        self.sema.release()
        return ret

    def appendleft(self, *args, **kwargs):
        ret = super(BlockingDeque, self).appendleft(*args, **kwargs)
        self.sema.release()
        return ret

    def clear(self, *args, **kwargs):
        ret = super(BlockingDeque, self).clear(*args, **kwargs)
        while not self.sema.locked():
            self.sema.acquire(blocking=False)
        return ret

    def extend(self, *args, **kwargs):
        pre_n = len(self)
        ret = super(BlockingDeque, self).extend(*args, **kwargs)
        post_n = len(self)
        for i in range(pre_n, post_n):
            self.sema.release()
        return ret

    def extendleft(self, *args, **kwargs):
        pre_n = len(self)
        ret = super(BlockingDeque, self).extendleft(*args, **kwargs)
        post_n = len(self)
        for i in range(pre_n, post_n):
            self.sema.release()
        return ret

    def pop(self, *args, **kwargs):
        self.sema.acquire()
        return super(BlockingDeque, self).pop(*args, **kwargs)

    def popleft(self, *args, **kwargs):
        self.sema.acquire()
        return super(BlockingDeque, self).popleft(*args, **kwargs)

    def remove(self, *args, **kwargs):
        ret = super(BlockingDeque, self).remove(*args, **kwargs)
        self.sema.acquire()
        return ret
Example #37
    def test_fair_or_hangs(self):
        # If the lock isn't fair, this hangs, spinning between
        # the last two greenlets.
        # See https://github.com/gevent/gevent/issues/1487
        sem = Semaphore()
        should_quit = []

        keep_going1 = FirstG.spawn(acquire_then_spawn, sem, should_quit)
        keep_going2 = FirstG.spawn(acquire_then_spawn, sem, should_quit)
        exiting = LastG.spawn(acquire_then_exit, sem, should_quit)

        with self.assertRaises(gevent.exceptions.LoopExit):
            gevent.joinall([keep_going1, keep_going2, exiting])

        self.assertTrue(exiting.dead, exiting)
        self.assertTrue(keep_going2.dead, keep_going2)
        self.assertFalse(keep_going1.dead, keep_going1)

        sem.release()
        keep_going1.kill()
        keep_going2.kill()
        exiting.kill()

        gevent.idle()
Example #38
class RollingNumber(object):
    _last_idx = 0

    @property
    def now(self):
        return int(time.time() * 1000)

    def __init__(self, window=10 * 1000, interval=1000):
        self.window = window
        self.interval = interval
        self.bucket_num = window // self.interval  # deque maxlen must be an int
        self._new_bucket_lock = Semaphore()

        self.reset(self.now)

    def reset(self, start=None):
        self._buckets = deque([0], maxlen=self.bucket_num)
        self._start = start or self.now

    def get_current_bucket_index(self):
        if not self._new_bucket_lock.acquire(0):
            return self._last_idx

        now = self.now
        elapsed = now - self._start

        if elapsed > self.interval:
            if elapsed > self.window:
                self.reset(now)
                self._new_bucket_lock.release()
                return self._last_idx

            t = elapsed - self.interval
            while t > 0:
                self._buckets.appendleft(0)
                t1, t = t, t - self.interval
                if t < 0:
                    self._start = now - t1
            self._new_bucket_lock.release()
            return self._last_idx

        self._new_bucket_lock.release()
        return self._last_idx

    @property
    def current_bucket_count(self):
        idx = self.get_current_bucket_index()
        return self._buckets[idx]

    def inc(self):
        i = self.get_current_bucket_index()
        self._buckets[i] += 1

    @property
    def sum(self):
        return sum(self._buckets)
Example #39
class AnalyticsDiscovery(gevent.Greenlet):

    def _sandesh_connection_info_update(self, status, message):

        new_conn_state = getattr(ConnectionStatus, status)
        ConnectionState.update(conn_type = ConnectionType.ZOOKEEPER,
                name = self._svc_name, status = new_conn_state,
                message = message,
                server_addrs = self._zk_server.split(','))

        if (self._conn_state and self._conn_state != ConnectionStatus.DOWN and
                new_conn_state == ConnectionStatus.DOWN):
            msg = 'Connection to Zookeeper down: %s' %(message)
            self._logger.error(msg)
        if (self._conn_state and self._conn_state != new_conn_state and
                new_conn_state == ConnectionStatus.UP):
            msg = 'Connection to Zookeeper ESTABLISHED'
            self._logger.error(msg)

        self._conn_state = new_conn_state
    # end _sandesh_connection_info_update

    def _zk_listen(self, state):
        self._logger.error("Analytics Discovery listen %s" % str(state))
        if state == KazooState.CONNECTED:
            self._sandesh_connection_info_update(status='UP', message='')
            self._logger.error("Analytics Discovery to publish %s" % str(self._pubinfo))
            self._reconnect = True
        elif state == KazooState.LOST:
            self._logger.error("Analytics Discovery connection LOST")
            # Lost the session with the ZooKeeper server.
            # The best option we have is to exit the process and restart
            # all over again
            self._sandesh_connection_info_update(status='DOWN',
                                      message='Connection to Zookeeper lost')
            os._exit(2)
        elif state == KazooState.SUSPENDED:
            self._logger.error("Analytics Discovery connection SUSPENDED")
            # Update connection info
            self._sandesh_connection_info_update(status='INIT',
                message = 'Connection to zookeeper lost. Retrying')

    def _zk_datawatch(self, watcher, child, data, stat, event="unknown"):
        self._logger.error(\
                "Analytics Discovery %s ChildData : child %s, data %s, event %s" % \
                (watcher, child, data, event))
        if data:
            data_dict = json.loads(data)
            self._wchildren[watcher][child] = OrderedDict(sorted(data_dict.items()))
        else:
            if child in self._wchildren[watcher]:
                del self._wchildren[watcher][child]
        if self._watchers[watcher]:
            self._pendingcb.add(watcher)

    def _zk_watcher(self, watcher, children):
        self._logger.error("Analytics Discovery Children %s" % children)
        self._reconnect = True

    def __init__(self, logger, zkservers, svc_name, inst,
                watchers={}, zpostfix="", freq=10):
        gevent.Greenlet.__init__(self)
        self._svc_name = svc_name
        self._inst = inst
        self._zk_server = zkservers
        # initialize logging and other stuff
        if logger is None:
            logging.basicConfig()
            self._logger = logging
        else:
            self._logger = logger
        self._conn_state = None
        self._sandesh_connection_info_update(status='INIT', message='')
        self._zkservers = zkservers
        self._zk = None
        self._pubinfo = None
        self._publock = Semaphore()
        self._watchers = watchers
        self._wchildren = {}
        self._pendingcb = set()
        self._zpostfix = zpostfix
        self._basepath = "/analytics-discovery-" + self._zpostfix
        self._reconnect = None
        self._freq = freq

    def publish(self, pubinfo):

        # This function can be called concurrently by the main AlarmDiscovery
        # processing loop as well as by clients.
        # It is NOT re-entrant
        self._publock.acquire()

        self._pubinfo = pubinfo
        if self._conn_state == ConnectionStatus.UP:
            try:
                self._logger.error("ensure %s" % (self._basepath + "/" + self._svc_name))
                self._logger.error("zk state %s (%s)" % (self._zk.state, self._zk.client_state))
                self._zk.ensure_path(self._basepath + "/" + self._svc_name)
                self._logger.error("check for %s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst))
                if pubinfo is not None:
                    if self._zk.exists("%s/%s/%s" % \
                            (self._basepath, self._svc_name, self._inst)):
                        self._zk.set("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst),
                                self._pubinfo)
                    else:
                        self._zk.create("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst),
                                self._pubinfo, ephemeral=True)
                else:
                    if self._zk.exists("%s/%s/%s" % \
                            (self._basepath, self._svc_name, self._inst)):
                        self._logger.error("withdrawing published info!")
                        self._zk.delete("%s/%s/%s" % \
                                (self._basepath, self._svc_name, self._inst))

            except Exception as ex:
                template = "Exception {0} in AnalyticsDiscovery publish. Args:\n{1!r}"
                messag = template.format(type(ex).__name__, ex.args)
                self._logger.error("%s : traceback %s for %s info %s" % \
                        (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
                self._sandesh_connection_info_update(status='DOWN', message='')
                self._reconnect = True
        else:
            self._logger.error("Analytics Discovery cannot publish while down")
        self._publock.release()

    def _run(self):
        while True:
            self._logger.error("Analytics Discovery zk start")
            self._zk = KazooClient(hosts=self._zkservers)
            self._zk.add_listener(self._zk_listen)
            try:
                self._zk.start()
                while self._conn_state != ConnectionStatus.UP:
                    gevent.sleep(1)
                break
            except Exception as e:
                # Update connection info
                self._sandesh_connection_info_update(status='DOWN',
                                                     message=str(e))
                self._zk.remove_listener(self._zk_listen)
                try:
                    self._zk.stop()
                    self._zk.close()
                except Exception as ex:
                    template = "Exception {0} in AnalyticsDiscovery zk stop/close. Args:\n{1!r}"
                    messag = template.format(type(ex).__name__, ex.args)
                    self._logger.error("%s : traceback %s for %s" % \
                        (messag, traceback.format_exc(), self._svc_name))
                finally:
                    self._zk = None
                gevent.sleep(1)

        try:
            # Update connection info
            self._sandesh_connection_info_update(status='UP', message='')
            self._reconnect = False
            # Done connecting to ZooKeeper

            for wk in self._watchers.keys():
                self._zk.ensure_path(self._basepath + "/" + wk)
                self._wchildren[wk] = {}
                self._zk.ChildrenWatch(self._basepath + "/" + wk,
                        partial(self._zk_watcher, wk))

            # Trigger the initial publish
            self._reconnect = True

            while True:
                try:
                    if not self._reconnect:
                        pending_list = list(self._pendingcb)
                        self._pendingcb = set()
                        for wk in pending_list:
                            if self._watchers[wk]:
                                self._watchers[wk](\
                                        sorted(self._wchildren[wk].values()))

                    # If a reconnect happens during processing, don't lose it
                    while self._reconnect:
                        self._logger.error("Analytics Discovery %s reconnect" \
                                % self._svc_name)
                        self._reconnect = False
                        self._pendingcb = set()
                        self.publish(self._pubinfo)

                        for wk in self._watchers.keys():
                            self._zk.ensure_path(self._basepath + "/" + wk)
                            children = self._zk.get_children(self._basepath + "/" + wk)

                            old_children = set(self._wchildren[wk].keys())
                            new_children = set(children)

                            # Remove contents for the children who are gone
                            # (DO NOT remove the watch)
                            for elem in old_children - new_children:
                                 del self._wchildren[wk][elem]

                            # Overwrite existing children, or create new ones
                            for elem in new_children:
                                # Create a watch for new children
                                if elem not in self._wchildren[wk]:
                                    self._zk.DataWatch(self._basepath + "/" + \
                                            wk + "/" + elem,
                                            partial(self._zk_datawatch, wk, elem))

                                data_str, _ = self._zk.get(\
                                        self._basepath + "/" + wk + "/" + elem)
                                data_dict = json.loads(data_str)
                                self._wchildren[wk][elem] = \
                                        OrderedDict(sorted(data_dict.items()))

                                self._logger.error(\
                                    "Analytics Discovery %s ChildData : child %s, data %s, event %s" % \
                                    (wk, elem, self._wchildren[wk][elem], "GET"))
                            if self._watchers[wk]:
                                self._watchers[wk](sorted(self._wchildren[wk].values()))

                    gevent.sleep(self._freq)
                except gevent.GreenletExit:
                    self._logger.error("Exiting AnalyticsDiscovery for %s" % \
                            self._svc_name)
                    self._zk.remove_listener(self._zk_listen)
                    gevent.sleep(1)
                    try:
                        self._zk.stop()
                    except:
                        self._logger.error("Stopping kazooclient failed")
                    else:
                        self._logger.error("Stopping kazooclient successful")
                    try:
                        self._zk.close()
                    except:
                        self._logger.error("Closing kazooclient failed")
                    else:
                        self._logger.error("Closing kazooclient successful")
                    break

                except Exception as ex:
                    template = "Exception {0} in AnalyticsDiscovery reconnect. Args:\n{1!r}"
                    messag = template.format(type(ex).__name__, ex.args)
                    self._logger.error("%s : traceback %s for %s info %s" % \
                        (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
                    self._reconnect = True

        except Exception as ex:
            template = "Exception {0} in AnalyticsDiscovery run. Args:\n{1!r}"
            messag = template.format(type(ex).__name__, ex.args)
            self._logger.error("%s : traceback %s for %s info %s" % \
                    (messag, traceback.format_exc(), self._svc_name, str(self._pubinfo)))
            raise SystemExit
Example #40
class ThreadPool(object):

    def __init__(self, maxsize, hub=None):
        if hub is None:
            hub = get_hub()
        self.hub = hub
        self._maxsize = 0
        self.manager = None
        self.pid = os.getpid()
        self.fork_watcher = hub.loop.fork(ref=False)
        self._init(maxsize)

    def _set_maxsize(self, maxsize):
        if not isinstance(maxsize, integer_types):
            raise TypeError('maxsize must be integer: %r' % (maxsize, ))
        if maxsize < 0:
            raise ValueError('maxsize must not be negative: %r' % (maxsize, ))
        difference = maxsize - self._maxsize
        self._semaphore.counter += difference
        self._maxsize = maxsize
        self.adjust()
        # make sure all currently blocking spawn() start unlocking if maxsize increased
        self._semaphore._start_notify()

    def _get_maxsize(self):
        return self._maxsize

    maxsize = property(_get_maxsize, _set_maxsize)

    def __repr__(self):
        return '<%s at 0x%x %s/%s/%s>' % (self.__class__.__name__, id(self), len(self), self.size, self.maxsize)

    def __len__(self):
        # XXX just do unfinished_tasks property
        return self.task_queue.unfinished_tasks

    def _get_size(self):
        return self._size

    def _set_size(self, size):
        if size < 0:
            raise ValueError('Size of the pool cannot be negative: %r' % (size, ))
        if size > self._maxsize:
            raise ValueError('Size of the pool cannot be bigger than maxsize: %r > %r' % (size, self._maxsize))
        if self.manager:
            self.manager.kill()
        while self._size < size:
            self._add_thread()
        delay = 0.0001
        while self._size > size:
            while self._size - size > self.task_queue.unfinished_tasks:
                self.task_queue.put(None)
            if getcurrent() is self.hub:
                break
            sleep(delay)
            delay = min(delay * 2, .05)
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    size = property(_get_size, _set_size)

    def _init(self, maxsize):
        self._size = 0
        self._semaphore = Semaphore(1)
        self._lock = Lock()
        self.task_queue = Queue()
        self._set_maxsize(maxsize)

    def _on_fork(self):
        # fork() only leaves one thread; also screws up locks;
        # let's re-create locks and threads
        pid = os.getpid()
        if pid != self.pid:
            self.pid = pid
            # Do not mix fork() and threads; since fork() only copies one thread,
            # all objects referenced by other threads have refcounts that will
            # never go down to 0.
            self._init(self._maxsize)

    def join(self):
        delay = 0.0005
        while self.task_queue.unfinished_tasks > 0:
            sleep(delay)
            delay = min(delay * 2, .05)

    def kill(self):
        self.size = 0

    def _adjust_step(self):
        # if there is a possibility & necessity for adding a thread, do it
        while self._size < self._maxsize and self.task_queue.unfinished_tasks > self._size:
            self._add_thread()
        # while the number of threads is more than maxsize, kill one
        # we do not check what's already in task_queue - it could be all Nones
        while self._size - self._maxsize > self.task_queue.unfinished_tasks:
            self.task_queue.put(None)
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    def _adjust_wait(self):
        delay = 0.0001
        while True:
            self._adjust_step()
            if self._size <= self._maxsize:
                return
            sleep(delay)
            delay = min(delay * 2, .05)

    def adjust(self):
        self._adjust_step()
        if not self.manager and self._size > self._maxsize:
            # might need to feed more Nones into the pool
            self.manager = Greenlet.spawn(self._adjust_wait)

    def _add_thread(self):
        with self._lock:
            self._size += 1
        try:
            start_new_thread(self._worker, ())
        except:
            with self._lock:
                self._size -= 1
            raise

    def spawn(self, func, *args, **kwargs):
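        # _on_fork()/_init() may replace self._semaphore wholesale after a fork,
        # so re-check identity after acquiring and retry if we got a stale one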
        while True:
            semaphore = self._semaphore
            semaphore.acquire()
            if semaphore is self._semaphore:
                break
        try:
            task_queue = self.task_queue
            result = AsyncResult()
            thread_result = ThreadResult(result, hub=self.hub)
            task_queue.put((func, args, kwargs, thread_result))
            self.adjust()
            # rawlink() must be the last call
            result.rawlink(lambda *args: self._semaphore.release())
            # XXX this _semaphore.release() is competing for order with get()
            # XXX this is not good, just make ThreadResult release the semaphore before doing anything else
        except:
            semaphore.release()
            raise
        return result

    def _decrease_size(self):
        if sys is None:
            return
        _lock = getattr(self, '_lock', None)
        if _lock is not None:
            with _lock:
                self._size -= 1

    def _worker(self):
        need_decrease = True
        try:
            while True:
                task_queue = self.task_queue
                task = task_queue.get()
                try:
                    if task is None:
                        need_decrease = False
                        self._decrease_size()
                        # we want to decrease size first, then unfinished_tasks;
                        # otherwise _adjust might think there's one more idle
                        # thread that needs to be killed
                        return
                    func, args, kwargs, thread_result = task
                    try:
                        value = func(*args, **kwargs)
                    except:
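                        # during interpreter shutdown, module globals (including
                        # sys) may already have been torn down to None, so look
                        # them up defensively and bail out if they are gone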
                        exc_info = getattr(sys, 'exc_info', None)
                        if exc_info is None:
                            return
                        thread_result.handle_error((self, func), exc_info())
                    else:
                        if sys is None:
                            return
                        thread_result.set(value)
                        del value
                    finally:
                        del func, args, kwargs, thread_result, task
                finally:
                    if sys is None:
                        return
                    task_queue.task_done()
        finally:
            if need_decrease:
                self._decrease_size()

    # XXX apply() should re-raise error by default
    # XXX because that's what builtin apply does
    # XXX check gevent.pool.Pool.apply and multiprocessing.Pool.apply
    def apply_e(self, expected_errors, function, args=None, kwargs=None):
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        success, result = self.spawn(wrap_errors, expected_errors, function, args, kwargs).get()
        if success:
            return result
        result = [result]
        raise result.pop()

    def apply(self, func, args=None, kwds=None):
        """Equivalent of the apply() builtin function. It blocks till the result is ready."""
        if args is None:
            args = ()
        if kwds is None:
            kwds = {}
        return self.spawn(func, *args, **kwds).get()

    def apply_cb(self, func, args=None, kwds=None, callback=None):
        result = self.apply(func, args, kwds)
        if callback is not None:
            callback(result)
        return result

    def apply_async(self, func, args=None, kwds=None, callback=None):
        """A variant of the apply() method which returns a Greenlet object.

        If callback is specified then it should be a callable which accepts a
        single argument. When the result becomes ready, callback is applied to
        it (unless the call failed)."""
        if args is None:
            args = ()
        if kwds is None:
            kwds = {}
        return Greenlet.spawn(self.apply_cb, func, args, kwds, callback)

    def map(self, func, iterable):
        return list(self.imap(func, iterable))

    def map_cb(self, func, iterable, callback=None):
        result = self.map(func, iterable)
        if callback is not None:
            callback(result)
        return result

    def map_async(self, func, iterable, callback=None):
        """
        A variant of the map() method which returns a Greenlet object.

        If callback is specified then it should be a callable which accepts a
        single argument.
        """
        return Greenlet.spawn(self.map_cb, func, iterable, callback)

    def imap(self, func, iterable):
        """An equivalent of itertools.imap()"""
        return IMap.spawn(func, iterable, spawn=self.spawn)

    def imap_unordered(self, func, iterable):
        """The same as imap() except that the ordering of the results from the
        returned iterator should be considered in arbitrary order."""
        return IMapUnordered.spawn(func, iterable, spawn=self.spawn)
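
A minimal usage sketch for the pool above, assuming it is gevent's ThreadPool (the code matches gevent.threadpool); the fetch function and maxsize value are illustrative only:

import time

from gevent.threadpool import ThreadPool  # the class shown above

def fetch(n):
    # simulate blocking work that would otherwise stall the event loop
    time.sleep(0.01)
    return n * n

pool = ThreadPool(maxsize=3)       # at most 3 native worker threads
results = [pool.spawn(fetch, i) for i in range(10)]  # async results
pool.join()                        # wait until unfinished_tasks drops to 0
print([r.get() for r in results])  # -> [0, 1, 4, 9, ...]
pool.kill()                        # size = 0: feeds None sentinels to workers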
Beispiel #41
0
class Pool(Group):

    def __init__(self, size=None, greenlet_class=None):
        """
        Create a new pool.

        A pool is like a group, but the maximum number of members
        is governed by the *size* parameter.

        :keyword int size: If given, this non-negative integer is the
            maximum count of active greenlets that will be allowed in
            this pool. A few values have special significance:

            * ``None`` (the default) places no limit on the number of
              greenlets. This is useful when you need to track, but not limit,
              greenlets, as with :class:`gevent.pywsgi.WSGIServer`. A :class:`Group`
              may be a more efficient way to achieve the same effect.
            * ``0`` creates a pool that can never have any active greenlets. Attempting
              to spawn in this pool will block forever. This is only useful
              if an application uses :meth:`wait_available` with a timeout and checks
              :meth:`free_count` before attempting to spawn.
        """
        if size is not None and size < 0:
            raise ValueError('size must not be negative: %r' % (size, ))
        Group.__init__(self)
        self.size = size
        if greenlet_class is not None:
            self.greenlet_class = greenlet_class
        if size is None:
            self._semaphore = DummySemaphore()
        else:
            self._semaphore = Semaphore(size)

    def wait_available(self, timeout=None):
        """
        Wait until it's possible to spawn a greenlet in this pool.

        :param float timeout: If given, only wait the specified number
            of seconds.

        .. warning:: If the pool was initialized with a size of 0, this
           method will block forever unless a timeout is given.

        :return: A number indicating how many new greenlets can be put into
           the pool without blocking.

        .. versionchanged:: 1.1a3
            Added the ``timeout`` parameter.
        """
        return self._semaphore.wait(timeout=timeout)

    def full(self):
        """
        Return a boolean indicating whether this pool has any room for
        members. (True if it does, False if it doesn't.)
        """
        return self.free_count() <= 0

    def free_count(self):
        """
        Return a number indicating *approximately* how many more members
        can be added to this pool.
        """
        if self.size is None:
            return 1
        return max(0, self.size - len(self))

    def add(self, greenlet):
        """
        Begin tracking the given greenlet, blocking until space is available.

        .. seealso:: :meth:`Group.add`
        """
        self._semaphore.acquire()
        try:
            Group.add(self, greenlet)
        except:
            self._semaphore.release()
            raise

    def _discard(self, greenlet):
        Group._discard(self, greenlet)
        self._semaphore.release()
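
A brief sketch of the size semantics documented above, assuming the standard gevent.pool.Pool; the work function is illustrative:

import gevent
from gevent.pool import Pool

def work(n):
    gevent.sleep(0.01)
    return n

pool = Pool(size=2)
print(pool.free_count())   # 2: room for two greenlets
pool.spawn(work, 1)        # each spawn/add acquires the semaphore
pool.spawn(work, 2)
print(pool.full())         # True: free_count() <= 0

# With size=0 nothing can ever be spawned; wait_available() would block
# forever unless a timeout is given:
empty = Pool(size=0)
print(empty.wait_available(timeout=0.1))  # 0: timed out, no capacity
pool.join()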
Beispiel #42
0
class VncKombuClientBase(object):
    def _update_sandesh_status(self, status, msg=''):
        ConnectionState.update(conn_type=ConnectionType.DATABASE,
            name='RabbitMQ', status=status, message=msg,
            server_addrs=["%s:%s" % (self._rabbit_ip, self._rabbit_port)])
    # end _update_sandesh_status

    def publish(self, message):
        self._publish_queue.put(message)
    # end publish

    def __init__(self, rabbit_ip, rabbit_port, rabbit_user, rabbit_password,
                 rabbit_vhost, rabbit_ha_mode, q_name, subscribe_cb, logger):
        self._rabbit_ip = rabbit_ip
        self._rabbit_port = rabbit_port
        self._rabbit_user = rabbit_user
        self._rabbit_password = rabbit_password
        self._rabbit_vhost = rabbit_vhost
        self._subscribe_cb = subscribe_cb
        self._logger = logger
        self._publish_queue = Queue()
        self._conn_lock = Semaphore()

        self.obj_upd_exchange = kombu.Exchange('vnc_config.object-update', 'fanout',
                                               durable=False)

    def num_pending_messages(self):
        return self._publish_queue.qsize()
    # end num_pending_messages

    def prepare_to_consume(self):
        # override this method
        return

    def _reconnect(self, delete_old_q=False):
        if self._conn_lock.locked():
            # either the connection-monitor or the publisher has taken the
            # lock. Whichever acquired it will re-establish the connection and
            # release the lock, so the other one can just wait on the lock
            # until it gets released
            self._conn_lock.wait()
            return

        self._conn_lock.acquire()

        msg = "RabbitMQ connection down"
        self._logger(msg, level=SandeshLevel.SYS_ERR)
        self._update_sandesh_status(ConnectionStatus.DOWN)
        self._conn_state = ConnectionStatus.DOWN

        self._conn.close()

        self._conn.ensure_connection()
        self._conn.connect()

        self._update_sandesh_status(ConnectionStatus.UP)
        self._conn_state = ConnectionStatus.UP
        msg = 'RabbitMQ connection ESTABLISHED %s' % repr(self._conn)
        self._logger(msg, level=SandeshLevel.SYS_NOTICE)

        self._channel = self._conn.channel()
        if delete_old_q:
            # delete the old queue in the first-connect context,
            # as the db-resync will have caught up with history.
            try:
                bound_q = self._update_queue_obj(self._channel)
                bound_q.delete()
            except Exception as e:
                msg = 'Unable to delete the old AMQP queue: %s' % (str(e))
                self._logger(msg, level=SandeshLevel.SYS_ERR)

        self._consumer = kombu.Consumer(self._channel,
                                       queues=self._update_queue_obj,
                                       callbacks=[self._subscribe])
        self._producer = kombu.Producer(self._channel, exchange=self.obj_upd_exchange)

        self._conn_lock.release()
    # end _reconnect

    def _connection_watch(self):
        self.prepare_to_consume()
        while True:
            try:
                self._consumer.consume()
                self._conn.drain_events()
            except self._conn.connection_errors + self._conn.channel_errors as e:
                self._reconnect()
    # end _connection_watch

    def _publisher(self):
        message = None
        while True:
            try:
                if not message:
                    # the previous message was sent fine; dequeue one more
                    message = self._publish_queue.get()

                while True:
                    try:
                        self._producer.publish(message)
                        message = None
                        break
                    except self._conn.connection_errors + self._conn.channel_errors as e:
                        self._reconnect()
            except Exception as e:
                log_str = "Unknown exception in _publisher greenlet" + str(e)
                self._logger(log_str, level=SandeshLevel.SYS_ERR)
    # end _publisher

    def _subscribe(self, body, message):
        try:
            self._subscribe_cb(body)
        finally:
            message.ack()


    def _start(self):
        self._reconnect(delete_old_q=True)

        self._publisher_greenlet = gevent.spawn(self._publisher)
        self._connection_monitor_greenlet = gevent.spawn(self._connection_watch)

    def shutdown(self):
        self._publisher_greenlet.kill()
        self._connection_monitor_greenlet.kill()
        self._producer.close()
        self._consumer.close()
        self._conn.close()
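
The _reconnect() comments above describe a common gevent pattern: whichever greenlet wins the semaphore re-establishes the connection, while the loser merely waits for the release. A stripped-down sketch of that pattern, with reconnect_transport() as a hypothetical stand-in for the kombu-specific work:

from gevent.lock import Semaphore

class Reconnector:
    def __init__(self):
        self._conn_lock = Semaphore()

    def reconnect(self):
        if self._conn_lock.locked():
            # someone else is already reconnecting; just wait for them
            self._conn_lock.wait()
            return
        self._conn_lock.acquire()
        try:
            reconnect_transport()  # hypothetical: rebuild conn/channel here
        finally:
            self._conn_lock.release()

The check-then-acquire sequence is safe only because gevent greenlets switch on blocking calls, so no other greenlet can grab the lock between locked() and acquire().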
Beispiel #43
0
class Rotkehlchen():
    def __init__(self, args: argparse.Namespace) -> None:
        """Initialize the Rotkehlchen object

        May Raise:
        - SystemPermissionError if the given data directory's permissions
        are not correct.
        """
        self.lock = Semaphore()
        self.lock.acquire()

        # Can also be None after unlock if premium credentials did not
        # authenticate or premium server temporarily offline
        self.premium: Optional[Premium] = None
        self.user_is_logged_in: bool = False
        configure_logging(args)

        self.sleep_secs = args.sleep_secs
        if args.data_dir is None:
            self.data_dir = default_data_directory()
        else:
            self.data_dir = Path(args.data_dir)

        if not os.access(self.data_dir, os.W_OK | os.R_OK):
            raise SystemPermissionError(
                f'The given data directory {self.data_dir} is not readable or writable',
            )
        self.main_loop_spawned = False
        self.args = args
        self.api_task_greenlets: List[gevent.Greenlet] = []
        self.msg_aggregator = MessagesAggregator()
        self.greenlet_manager = GreenletManager(
            msg_aggregator=self.msg_aggregator)
        self.exchange_manager = ExchangeManager(
            msg_aggregator=self.msg_aggregator)
        # Initialize the AssetResolver singleton
        AssetResolver(data_directory=self.data_dir)
        self.data = DataHandler(self.data_dir, self.msg_aggregator)
        self.cryptocompare = Cryptocompare(data_directory=self.data_dir,
                                           database=None)
        self.coingecko = Coingecko(data_directory=self.data_dir)
        self.icon_manager = IconManager(data_dir=self.data_dir,
                                        coingecko=self.coingecko)
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='periodically_query_icons_until_all_cached',
            exception_is_error=False,
            method=self.icon_manager.periodically_query_icons_until_all_cached,
            batch_size=ICONS_BATCH_SIZE,
            sleep_time_secs=ICONS_QUERY_SLEEP,
        )
        # Initialize the Inquirer singleton
        Inquirer(
            data_dir=self.data_dir,
            cryptocompare=self.cryptocompare,
            coingecko=self.coingecko,
        )
        # Keeps how many trades we have found per location. Used for free user limiting
        self.actions_per_location: Dict[str, Dict[Location, int]] = {
            'trade': defaultdict(int),
            'asset_movement': defaultdict(int),
        }

        self.lock.release()
        self.task_manager: Optional[TaskManager] = None
        self.shutdown_event = gevent.event.Event()

    def reset_after_failed_account_creation_or_login(self) -> None:
        """If the account creation or login failed make sure that the Rotki instance is clear

        Tricky cases arise when we log in again after either failed premium
        credentials or the user's refusal to sync premium databases.
        """
        self.cryptocompare.db = None

    def unlock_user(
        self,
        user: str,
        password: str,
        create_new: bool,
        sync_approval: Literal['yes', 'no', 'unknown'],
        premium_credentials: Optional[PremiumCredentials],
        initial_settings: Optional[ModifiableDBSettings] = None,
    ) -> None:
        """Unlocks an existing user or creates a new one if `create_new` is True

        May raise:
        - AuthenticationError if the password can't unlock the database.
        - PremiumAuthenticationError if premium_credentials are given and are
        invalid or can't authenticate with the server
        - DBUpgradeError if the rotki DB version is newer than the software or
        there is a DB upgrade and there is an error.
        - SystemPermissionError if the directory or DB file can not be accessed
        """
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
            initial_settings=initial_settings,
        )

        # unlock or create the DB
        self.password = password
        self.user_directory = self.data.unlock(user, password, create_new,
                                               initial_settings)
        self.data_importer = DataImporter(db=self.data.db)
        self.last_data_upload_ts = self.data.db.get_last_data_upload_ts()
        self.premium_sync_manager = PremiumSyncManager(data=self.data,
                                                       password=password)
        # set the DB in the external services instances that need it
        self.cryptocompare.set_database(self.data.db)

        # Anything that was set above here has to be cleaned in case of failure in the next step
        # by reset_after_failed_account_creation_or_login()
        try:
            self.premium = self.premium_sync_manager.try_premium_at_start(
                given_premium_credentials=premium_credentials,
                username=user,
                create_new=create_new,
                sync_approval=sync_approval,
            )
        except PremiumAuthenticationError:
            # Reraise it only if this is during the creation of a new account where
            # the premium credentials were given by the user
            if create_new:
                raise
            self.msg_aggregator.add_warning(
                'Could not authenticate the Rotki premium API keys found in the DB.'
                ' Has your subscription expired?', )
            # else let's just continue. The user signed in successfully, but
            # has unauthenticatable/invalid premium credentials left in the DB

        settings = self.get_settings()
        self.greenlet_manager.spawn_and_track(
            after_seconds=None,
            task_name='submit_usage_analytics',
            exception_is_error=False,
            method=maybe_submit_usage_analytics,
            should_submit=settings.submit_usage_analytics,
        )
        self.etherscan = Etherscan(database=self.data.db,
                                   msg_aggregator=self.msg_aggregator)
        self.beaconchain = BeaconChain(database=self.data.db,
                                       msg_aggregator=self.msg_aggregator)
        eth_rpc_endpoint = settings.eth_rpc_endpoint
        # Initialize the price historian singleton
        PriceHistorian(
            data_directory=self.data_dir,
            cryptocompare=self.cryptocompare,
            coingecko=self.coingecko,
        )
        PriceHistorian().set_oracles_order(settings.historical_price_oracles)

        self.accountant = Accountant(
            db=self.data.db,
            user_directory=self.user_directory,
            msg_aggregator=self.msg_aggregator,
            create_csv=True,
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=settings.anonymized_logs)
        exchange_credentials = self.data.db.get_exchange_credentials()
        self.exchange_manager.initialize_exchanges(
            exchange_credentials=exchange_credentials,
            database=self.data.db,
        )

        # Initialize blockchain querying modules
        ethereum_manager = EthereumManager(
            ethrpc_endpoint=eth_rpc_endpoint,
            etherscan=self.etherscan,
            database=self.data.db,
            msg_aggregator=self.msg_aggregator,
            greenlet_manager=self.greenlet_manager,
            connect_at_start=ETHEREUM_NODES_TO_CONNECT_AT_START,
        )
        kusama_manager = SubstrateManager(
            chain=SubstrateChain.KUSAMA,
            msg_aggregator=self.msg_aggregator,
            greenlet_manager=self.greenlet_manager,
            connect_at_start=KUSAMA_NODES_TO_CONNECT_AT_START,
            connect_on_startup=self._connect_ksm_manager_on_startup(),
            own_rpc_endpoint=settings.ksm_rpc_endpoint,
        )

        Inquirer().inject_ethereum(ethereum_manager)
        Inquirer().set_oracles_order(settings.current_price_oracles)

        self.chain_manager = ChainManager(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            ethereum_manager=ethereum_manager,
            kusama_manager=kusama_manager,
            msg_aggregator=self.msg_aggregator,
            database=self.data.db,
            greenlet_manager=self.greenlet_manager,
            premium=self.premium,
            eth_modules=settings.active_modules,
            data_directory=self.data_dir,
            beaconchain=self.beaconchain,
            btc_derivation_gap_limit=settings.btc_derivation_gap_limit,
        )
        self.events_historian = EventsHistorian(
            user_directory=self.user_directory,
            db=self.data.db,
            msg_aggregator=self.msg_aggregator,
            exchange_manager=self.exchange_manager,
            chain_manager=self.chain_manager,
        )
        self.task_manager = TaskManager(
            max_tasks_num=DEFAULT_MAX_TASKS_NUM,
            greenlet_manager=self.greenlet_manager,
            api_task_greenlets=self.api_task_greenlets,
            database=self.data.db,
            cryptocompare=self.cryptocompare,
            premium_sync_manager=self.premium_sync_manager,
            chain_manager=self.chain_manager,
            exchange_manager=self.exchange_manager,
        )
        self.user_is_logged_in = True
        log.debug('User unlocking complete')

    def logout(self) -> None:
        if not self.user_is_logged_in:
            return

        user = self.data.username
        log.info(
            'Logging out user',
            user=user,
        )
        self.greenlet_manager.clear()
        del self.chain_manager
        self.exchange_manager.delete_all_exchanges()

        # Reset rotkehlchen logger to default
        LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

        del self.accountant
        del self.events_historian
        del self.data_importer

        if self.premium is not None:
            del self.premium
        self.data.logout()
        self.password = ''
        self.cryptocompare.unset_database()

        # Make sure no messages leak to other user sessions
        self.msg_aggregator.consume_errors()
        self.msg_aggregator.consume_warnings()
        self.task_manager = None

        self.user_is_logged_in = False
        log.info(
            'User successfully logged out',
            user=user,
        )

    def set_premium_credentials(self, credentials: PremiumCredentials) -> None:
        """
        Sets the premium credentials for Rotki

        Raises PremiumAuthenticationError if the given key is rejected by the Rotkehlchen server
        """
        log.info('Setting new premium credentials')
        if self.premium is not None:
            self.premium.set_credentials(credentials)
        else:
            self.premium = premium_create_and_verify(credentials)

        self.data.db.set_rotkehlchen_premium(credentials)

    def delete_premium_credentials(self) -> Tuple[bool, str]:
        """Deletes the premium credentials for Rotki"""
        msg = ''

        success = self.data.db.del_rotkehlchen_premium()
        if success is False:
            msg = 'The database was unable to delete the Premium keys for the logged-in user'
        self.deactivate_premium_status()
        return success, msg

    def deactivate_premium_status(self) -> None:
        """Deactivate premium in the current session"""
        self.premium = None
        self.premium_sync_manager.premium = None
        self.chain_manager.deactivate_premium_status()

    def start(self) -> gevent.Greenlet:
        assert not self.main_loop_spawned, 'Tried to spawn the main loop twice'
        greenlet = gevent.spawn(self.main_loop)
        self.main_loop_spawned = True
        return greenlet

    def main_loop(self) -> None:
        """Rotki main loop that fires often and runs the task manager's scheduler"""
        while self.shutdown_event.wait(
                timeout=MAIN_LOOP_SECS_DELAY) is not True:
            if self.task_manager is not None:
                self.task_manager.schedule()

    def get_blockchain_account_data(
        self,
        blockchain: SupportedBlockchain,
    ) -> Union[List[BlockchainAccountData], Dict[str, Any]]:
        account_data = self.data.db.get_blockchain_account_data(blockchain)
        if blockchain != SupportedBlockchain.BITCOIN:
            return account_data

        xpub_data = self.data.db.get_bitcoin_xpub_data()
        addresses_to_account_data = {x.address: x for x in account_data}
        address_to_xpub_mappings = self.data.db.get_addresses_to_xpub_mapping(
            list(addresses_to_account_data.keys()),  # type: ignore
        )

        xpub_mappings: Dict['XpubData', List[BlockchainAccountData]] = {}
        for address, xpub_entry in address_to_xpub_mappings.items():
            if xpub_entry not in xpub_mappings:
                xpub_mappings[xpub_entry] = []
            xpub_mappings[xpub_entry].append(
                addresses_to_account_data[address])

        data: Dict[str, Any] = {'standalone': [], 'xpubs': []}
        # Add xpub data
        for xpub_entry in xpub_data:
            data_entry = xpub_entry.serialize()
            addresses = xpub_mappings.get(xpub_entry, None)
            data_entry['addresses'] = addresses if addresses and len(addresses) != 0 else None
            data['xpubs'].append(data_entry)
        # Add standalone addresses
        for account in account_data:
            if account.address not in address_to_xpub_mappings:
                data['standalone'].append(account)

        return data

    def add_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
    ) -> BlockchainBalancesUpdate:
        """Adds new blockchain accounts

        Adds the accounts to the blockchain instance and queries them to get the
        updated balances. Also adds them to the DB.

        May raise:
        - EthSyncError from modify_blockchain_account
        - InputError if the given accounts list is empty.
        - TagConstraintError if any of the given account data contain unknown tags.
        - RemoteError if an external service such as Etherscan is queried and
          there is a problem with its query.
        """
        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='adding',
            data_type='blockchain accounts',
        )
        address_type = blockchain.get_address_type()
        updated_balances = self.chain_manager.add_blockchain_accounts(
            blockchain=blockchain,
            accounts=[address_type(entry.address) for entry in account_data],
        )
        self.data.db.add_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )

        return updated_balances

    def edit_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        account_data: List[BlockchainAccountData],
    ) -> None:
        """Edits blockchain accounts

        Edits blockchain account data for the given accounts

        May raise:
        - InputError if the given accounts list is empty or if
        any of the accounts to edit do not exist.
        - TagConstraintError if any of the given account data contain unknown tags.
        """
        # First check for validity of account data addresses
        if len(account_data) == 0:
            raise InputError(
                'Empty list of blockchain account data to edit was given')
        accounts = [x.address for x in account_data]
        unknown_accounts = set(accounts).difference(
            self.chain_manager.accounts.get(blockchain))
        if len(unknown_accounts) != 0:
            raise InputError(
                f'Tried to edit unknown {blockchain.value} '
                f'accounts {",".join(unknown_accounts)}', )

        self.data.db.ensure_tags_exist(
            given_data=account_data,
            action='editing',
            data_type='blockchain accounts',
        )

        # Finally edit the accounts
        self.data.db.edit_blockchain_accounts(
            blockchain=blockchain,
            account_data=account_data,
        )

    def remove_blockchain_accounts(
        self,
        blockchain: SupportedBlockchain,
        accounts: ListOfBlockchainAddresses,
    ) -> BlockchainBalancesUpdate:
        """Removes blockchain accounts

        Removes the accounts from the blockchain instance and queries them to get
        the updated balances. Also removes them from the DB

        May raise:
        - RemoteError if an external service such as Etherscan is queried and
          there is a problem with its query.
        - InputError if a non-existing account was given to remove
        """
        balances_update = self.chain_manager.remove_blockchain_accounts(
            blockchain=blockchain,
            accounts=accounts,
        )
        self.data.db.remove_blockchain_accounts(blockchain, accounts)
        return balances_update

    def get_history_query_status(self) -> Dict[str, str]:
        if self.events_historian.progress < FVal('100'):
            processing_state = self.events_historian.processing_state_name
            progress = self.events_historian.progress / 2
        elif self.accountant.first_processed_timestamp == -1:
            processing_state = 'Processing all retrieved historical events'
            progress = FVal(50)
        else:
            processing_state = 'Processing all retrieved historical events'
            # start_ts is min of the query start or the first action timestamp since action
            # processing can start well before query start to calculate cost basis
            start_ts = min(
                self.accountant.events.query_start_ts,
                self.accountant.first_processed_timestamp,
            )
            diff = self.accountant.events.query_end_ts - start_ts
            progress = 50 + 100 * (
                FVal(self.accountant.currently_processing_timestamp - start_ts)
                / FVal(diff) / 2)
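            # e.g. with query range 1000..2000 and first processed ts 900:
            # start_ts = 900, diff = 1100; at currently_processing = 1450
            # progress = 50 + 100 * (550 / 1100) / 2 = 75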

        return {
            'processing_state': str(processing_state),
            'total_progress': str(progress)
        }

    def process_history(
        self,
        start_ts: Timestamp,
        end_ts: Timestamp,
    ) -> Tuple[Dict[str, Any], str]:
        (
            error_or_empty,
            history,
            loan_history,
            asset_movements,
            eth_transactions,
            defi_events,
            ledger_actions,
        ) = self.events_historian.get_history(
            start_ts=start_ts,
            end_ts=end_ts,
            has_premium=self.premium is not None,
        )
        result = self.accountant.process_history(
            start_ts=start_ts,
            end_ts=end_ts,
            trade_history=history,
            loan_history=loan_history,
            asset_movements=asset_movements,
            eth_transactions=eth_transactions,
            defi_events=defi_events,
            ledger_actions=ledger_actions,
        )
        return result, error_or_empty

    @overload
    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade'],
        location_actions: TRADES_LIST,
        all_actions: TRADES_LIST,
    ) -> TRADES_LIST:
        ...

    @overload
    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['asset_movement'],
        location_actions: List[AssetMovement],
        all_actions: List[AssetMovement],
    ) -> List[AssetMovement]:
        ...

    def _apply_actions_limit(
        self,
        location: Location,
        action_type: Literal['trade', 'asset_movement'],
        location_actions: Union[TRADES_LIST, List[AssetMovement]],
        all_actions: Union[TRADES_LIST, List[AssetMovement]],
    ) -> Union[TRADES_LIST, List[AssetMovement]]:
        """Take as many actions from location actions and add them to all actions as the limit permits

        Returns the modified (or not) all_actions
        """
        # If we are already at or above the limit return current actions disregarding this location
        actions_mapping = self.actions_per_location[action_type]
        current_num_actions = sum(x for _, x in actions_mapping.items())
        limit = LIMITS_MAPPING[action_type]
        if current_num_actions >= limit:
            return all_actions

        # Find out how many more actions we can return, and depending on that get
        # the number of actions from the location actions and add them to the total
        remaining_num_actions = limit - current_num_actions
        if remaining_num_actions < 0:
            remaining_num_actions = 0

        num_actions_to_take = min(len(location_actions), remaining_num_actions)

        actions_mapping[location] = num_actions_to_take
        all_actions.extend(
            location_actions[0:num_actions_to_take])  # type: ignore
        return all_actions

    def query_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
    ) -> TRADES_LIST:
        """Queries trades for the given location and time range.
        If no location is given then all external trades, all exchange trades
        and DEX trades are queried.

        DEX trades are queried only if the user has premium.
        If the user does not have premium then a trade limit is applied.

        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        trades: TRADES_LIST
        if location is not None:
            trades = self.query_location_trades(from_ts, to_ts, location)
        else:
            trades = self.query_location_trades(from_ts, to_ts,
                                                Location.EXTERNAL)
            # crypto.com is not an API-key-supported exchange, but the user can import from CSV
            trades.extend(
                self.query_location_trades(from_ts, to_ts, Location.CRYPTOCOM))
            for name, exchange in self.exchange_manager.connected_exchanges.items():
                exchange_trades = exchange.query_trade_history(
                    start_ts=from_ts, end_ts=to_ts)
                if self.premium is None:
                    trades = self._apply_actions_limit(
                        location=deserialize_location(name),
                        action_type='trade',
                        location_actions=exchange_trades,
                        all_actions=trades,
                    )
                else:
                    trades.extend(exchange_trades)

            # for all trades we also need uniswap trades
            if self.premium is not None:
                uniswap = self.chain_manager.uniswap
                if uniswap is not None:
                    trades.extend(
                        uniswap.get_trades(
                            addresses=self.chain_manager.queried_addresses_for_module('uniswap'),
                            from_timestamp=from_ts,
                            to_timestamp=to_ts,
                        ), )

        # return trades with most recent first
        trades.sort(key=lambda x: x.timestamp, reverse=True)
        return trades

    def query_location_trades(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Location,
    ) -> TRADES_LIST:
        # clear the trades queried for this location
        self.actions_per_location['trade'][location] = 0

        location_trades: TRADES_LIST
        if location in (Location.EXTERNAL, Location.CRYPTOCOM):
            location_trades = self.data.db.get_trades(  # type: ignore  # list invariance
                from_ts=from_ts,
                to_ts=to_ts,
                location=location,
            )
        elif location == Location.UNISWAP:
            if self.premium is not None:
                uniswap = self.chain_manager.uniswap
                if uniswap is not None:
                    location_trades = uniswap.get_trades(  # type: ignore  # list invariance
                        addresses=self.chain_manager.queried_addresses_for_module('uniswap'),
                        from_timestamp=from_ts,
                        to_timestamp=to_ts,
                    )
        else:
            # should only be an exchange
            exchange = self.exchange_manager.get(str(location))
            if not exchange:
                logger.warning(
                    f'Tried to query trades from {location} which is either not an '
                    f'exchange or not an exchange the user has connected to', )
                return []

            location_trades = exchange.query_trade_history(start_ts=from_ts,
                                                           end_ts=to_ts)

        trades: TRADES_LIST = []
        if self.premium is None:
            trades = self._apply_actions_limit(
                location=location,
                action_type='trade',
                location_actions=location_trades,
                all_actions=trades,
            )
        else:
            trades = location_trades

        return trades

    def query_balances(
        self,
        requested_save_data: bool = False,
        timestamp: Optional[Timestamp] = None,
        ignore_cache: bool = False,
    ) -> Dict[str, Any]:
        """Query all balances rotkehlchen can see.

        If requested_save_data is True then the data are always saved in the DB,
        if it is False then data are saved if self.data.should_save_balances()
        is True.
        If timestamp is None then the current timestamp is used.
        If a timestamp is given then that is the time that the balances are going
        to be saved in the DB.
        If ignore_cache is True then all underlying calls that have a cache ignore it.

        Returns a dictionary with the queried balances.
        """
        log.info('query_balances called',
                 requested_save_data=requested_save_data)

        balances: Dict[str, Dict[Asset, Balance]] = {}
        problem_free = True
        for _, exchange in self.exchange_manager.connected_exchanges.items():
            exchange_balances, _ = exchange.query_balances(
                ignore_cache=ignore_cache)
            # If we got an error, disregard that exchange but make sure we don't save data
            if not isinstance(exchange_balances, dict):
                problem_free = False
            else:
                balances[exchange.name] = exchange_balances

        liabilities: Dict[Asset, Balance]
        try:
            blockchain_result = self.chain_manager.query_balances(
                blockchain=None,
                force_token_detection=ignore_cache,
                ignore_cache=ignore_cache,
            )
            balances[str(Location.BLOCKCHAIN)] = blockchain_result.totals.assets
            liabilities = blockchain_result.totals.liabilities
        except (RemoteError, EthSyncError) as e:
            problem_free = False
            liabilities = {}
            log.error(f'Querying blockchain balances failed due to: {str(e)}')

        balances = account_for_manually_tracked_balances(db=self.data.db,
                                                         balances=balances)

        # Calculate usd totals
        assets_total_balance: DefaultDict[Asset, Balance] = defaultdict(Balance)
        total_usd_per_location: Dict[str, FVal] = {}
        for location, asset_balance in balances.items():
            total_usd_per_location[location] = ZERO
            for asset, balance in asset_balance.items():
                assets_total_balance[asset] += balance
                total_usd_per_location[location] += balance.usd_value

        net_usd = sum((balance.usd_value
                       for _, balance in assets_total_balance.items()), ZERO)
        liabilities_total_usd = sum(
            (liability.usd_value for _, liability in liabilities.items()),
            ZERO)  # noqa: E501
        net_usd -= liabilities_total_usd

        # Calculate location stats
        location_stats: Dict[str, Any] = {}
        for location, total_usd in total_usd_per_location.items():
            if location == str(Location.BLOCKCHAIN):
                total_usd -= liabilities_total_usd

            percentage = (total_usd / net_usd).to_percentage() if net_usd != ZERO else '0%'
            location_stats[location] = {
                'usd_value': total_usd,
                'percentage_of_net_value': percentage,
            }

        # Calculate 'percentage_of_net_value' per asset
        assets_total_balance_as_dict: Dict[Asset, Dict[str, Any]] = {
            asset: balance.to_dict()
            for asset, balance in assets_total_balance.items()
        }
        liabilities_as_dict: Dict[Asset, Dict[str, Any]] = {
            asset: balance.to_dict()
            for asset, balance in liabilities.items()
        }
        for asset, balance_dict in assets_total_balance_as_dict.items():
            percentage = (balance_dict['usd_value'] / net_usd).to_percentage() if net_usd != ZERO else '0%'  # noqa: E501
            assets_total_balance_as_dict[asset]['percentage_of_net_value'] = percentage

        for asset, balance_dict in liabilities_as_dict.items():
            percentage = (balance_dict['usd_value'] / net_usd).to_percentage() if net_usd != ZERO else '0%'  # noqa: E501
            liabilities_as_dict[asset]['percentage_of_net_value'] = percentage

        # Compose balances response
        result_dict = {
            'assets': assets_total_balance_as_dict,
            'liabilities': liabilities_as_dict,
            'location': location_stats,
            'net_usd': net_usd,
        }
        allowed_to_save = requested_save_data or self.data.should_save_balances()

        if problem_free and allowed_to_save:
            if not timestamp:
                timestamp = Timestamp(int(time.time()))
            self.data.db.save_balances_data(data=result_dict,
                                            timestamp=timestamp)
            log.debug('query_balances data saved')
        else:
            log.debug(
                'query_balances data not saved',
                allowed_to_save=allowed_to_save,
                problem_free=problem_free,
            )

        return result_dict

    def _query_exchange_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        all_movements: List[AssetMovement],
        exchange: Union[ExchangeInterface, Location],
    ) -> List[AssetMovement]:
        if isinstance(exchange, ExchangeInterface):
            location = deserialize_location(exchange.name)
            # clear the asset movements queried for this exchange
            self.actions_per_location['asset_movement'][location] = 0
            location_movements = exchange.query_deposits_withdrawals(
                start_ts=from_ts,
                end_ts=to_ts,
            )
        else:
            assert isinstance(exchange,
                              Location), 'only a location should make it here'
            assert exchange == Location.CRYPTOCOM, 'only cryptocom should make it here'
            location = exchange
            # cryptocom has no exchange integration but we may have DB entries
            self.actions_per_location['asset_movement'][location] = 0
            location_movements = self.data.db.get_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                location=location,
            )

        movements: List[AssetMovement] = []
        if self.premium is None:
            movements = self._apply_actions_limit(
                location=location,
                action_type='asset_movement',
                location_actions=location_movements,
                all_actions=all_movements,
            )
        else:
            all_movements.extend(location_movements)
            movements = all_movements

        return movements

    def query_asset_movements(
        self,
        from_ts: Timestamp,
        to_ts: Timestamp,
        location: Optional[Location],
    ) -> List[AssetMovement]:
        """Queries AssetMovements for the given location and time range.

        If no location is given then all exchange asset movements are queried.
        If the user does not have premium then a limit is applied.
        May raise:
        - RemoteError: If there are problems connecting to any of the remote exchanges
        """
        movements: List[AssetMovement] = []
        if location is not None:
            if location == Location.CRYPTOCOM:
                movements = self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=Location.CRYPTOCOM,
                )
            else:
                exchange = self.exchange_manager.get(str(location))
                if not exchange:
                    logger.warning(
                        f'Tried to query deposits/withdrawals from {location} which is either '
                        f'not an exchange or not an exchange the user has connected to',
                    )
                    return []
                movements = self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=exchange,
                )
        else:
            # cryptocom has no exchange integration but we may have DB entries due to csv import
            movements = self._query_exchange_asset_movements(
                from_ts=from_ts,
                to_ts=to_ts,
                all_movements=movements,
                exchange=Location.CRYPTOCOM,
            )
            for _, exchange in self.exchange_manager.connected_exchanges.items():
                self._query_exchange_asset_movements(
                    from_ts=from_ts,
                    to_ts=to_ts,
                    all_movements=movements,
                    exchange=exchange,
                )

        # return movements with most recent first
        movements.sort(key=lambda x: x.timestamp, reverse=True)
        return movements

    def set_settings(self, settings: ModifiableDBSettings) -> Tuple[bool, str]:
        """Tries to set new settings. Returns True in success or False with message if error"""
        with self.lock:
            if settings.eth_rpc_endpoint is not None:
                result, msg = self.chain_manager.set_eth_rpc_endpoint(
                    settings.eth_rpc_endpoint)
                if not result:
                    return False, msg

            if settings.ksm_rpc_endpoint is not None:
                result, msg = self.chain_manager.set_ksm_rpc_endpoint(
                    settings.ksm_rpc_endpoint)
                if not result:
                    return False, msg

            if settings.kraken_account_type is not None:
                kraken = self.exchange_manager.get('kraken')
                if kraken:
                    kraken.set_account_type(
                        settings.kraken_account_type)  # type: ignore

            if settings.btc_derivation_gap_limit is not None:
                self.chain_manager.btc_derivation_gap_limit = settings.btc_derivation_gap_limit

            if settings.current_price_oracles is not None:
                Inquirer().set_oracles_order(settings.current_price_oracles)

            if settings.historical_price_oracles is not None:
                PriceHistorian().set_oracles_order(
                    settings.historical_price_oracles)

            self.data.db.set_settings(settings)
            return True, ''

    def get_settings(self) -> DBSettings:
        """Returns the db settings with a check whether premium is active or not"""
        db_settings = self.data.db.get_settings(
            have_premium=self.premium is not None)
        return db_settings

    def setup_exchange(
        self,
        name: str,
        api_key: ApiKey,
        api_secret: ApiSecret,
        passphrase: Optional[str] = None,
    ) -> Tuple[bool, str]:
        """
        Setup a new exchange with an api key and an api secret and optionally a passphrase

        The api keys are validated as part of the setup.
        """
        is_success, msg = self.exchange_manager.setup_exchange(
            name=name,
            api_key=api_key,
            api_secret=api_secret,
            database=self.data.db,
            passphrase=passphrase,
        )

        if is_success:
            # Success, save the result in the DB
            self.data.db.add_exchange(name,
                                      api_key,
                                      api_secret,
                                      passphrase=passphrase)
        return is_success, msg

    def remove_exchange(self, name: str) -> Tuple[bool, str]:
        if not self.exchange_manager.has_exchange(name):
            return False, 'Exchange {} is not registered'.format(name)

        self.exchange_manager.delete_exchange(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        self.data.db.delete_used_query_range_for_exchange(name)
        return True, ''

    def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
        """Query for frequently changing data"""
        result: Dict[str, Union[bool, Timestamp]] = {}

        if self.user_is_logged_in:
            result['last_balance_save'] = self.data.db.get_last_balance_save_time()
            result['eth_node_connection'] = self.chain_manager.ethereum.web3_mapping.get(
                NodeName.OWN, None) is not None  # noqa: E501
            result['last_data_upload_ts'] = Timestamp(
                self.premium_sync_manager.last_data_upload_ts)  # noqa: E501
        return result

    def shutdown(self) -> None:
        self.logout()
        self.shutdown_event.set()

    def _connect_ksm_manager_on_startup(self) -> bool:
        return bool(self.data.db.get_blockchain_accounts().ksm)

    def create_oracle_cache(
        self,
        oracle: HistoricalPriceOracle,
        from_asset: Asset,
        to_asset: Asset,
        purge_old: bool,
    ) -> None:
        """Creates the cache of the given asset pair from the start of time
        until now for the given oracle.

        If purge_old is True then any old cache in memory and on disk is purged.

        May raise:
            - RemoteError if there is a problem reaching the oracle
            - UnsupportedAsset if any of the two assets is not supported by the oracle
        """
        if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
            return  # only for cryptocompare for now

        self.cryptocompare.create_cache(from_asset, to_asset, purge_old)

    def delete_oracle_cache(
        self,
        oracle: HistoricalPriceOracle,
        from_asset: Asset,
        to_asset: Asset,
    ) -> None:
        if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
            return  # only for cryptocompare for now

        self.cryptocompare.delete_cache(from_asset, to_asset)

    def get_oracle_cache(
            self, oracle: HistoricalPriceOracle) -> List[Dict[str, Any]]:
        if oracle != HistoricalPriceOracle.CRYPTOCOMPARE:
            return []  # only for cryptocompare for now

        return self.cryptocompare.get_all_cache_data()
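
The _apply_actions_limit() helper above implements the free-tier cap: each location contributes actions only until a global per-type limit is reached. A self-contained sketch of the same bookkeeping, with an illustrative limit of 250 (the real limits live in LIMITS_MAPPING):

from collections import defaultdict

TRADE_LIMIT = 250  # illustrative value only

actions_per_location = defaultdict(int)

def apply_actions_limit(location, location_actions, all_actions):
    # already at or over the limit: disregard this location entirely
    current = sum(actions_per_location.values())
    if current >= TRADE_LIMIT:
        return all_actions
    # otherwise take only as many actions as still fit under the cap
    take = min(len(location_actions), TRADE_LIMIT - current)
    actions_per_location[location] = take
    all_actions.extend(location_actions[:take])
    return all_actions

trades = []
trades = apply_actions_limit('kraken', list(range(200)), trades)
trades = apply_actions_limit('binance', list(range(100)), trades)
print(len(trades))  # 250: binance contributed only the remaining 50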
Beispiel #44
0
class Spider:
    def __init__(self, student_shortcut, job_manager, db_manager):
        self.shortcut = student_shortcut
        self.job_manager = job_manager
        self.db_manager = db_manager

        self.term_course_mutex = Semaphore()
        self.term_course_flags = set()

        # Avoid duplicate writes to the database.
        # Term, major, plan and class-student relation records are never
        # fetched twice.
        # Course records are distinguished by 课程代码 (course code).
        # Per-term teaching classes are distinguished by 学期代码 (term code)
        # + 课程代码 (course code).
        # Student records are distinguished by 学号 (student ID).
        # The three never collide, so a single set is enough.
        # No lock was used here, so these flags were not guaranteed to hit;
        # locking would add a lot of waiting time, so directly updating the
        # database is just as good -- it did not improve performance much.
        # self._flags = {
        #     'course': ('课程代码', Semaphore(), set()),
        #     'student': ('学号', Semaphore(), set()),
        # }
        # Statistics: in-memory flag hit counts
        # self._flags_hit_count = {
        #     'course': 0,
        #     'student': 0,
        # }

    @elapse_dec
    def crawl(self, dfs_mode=False):
        self.job_manager.jobs = (self.iter_term_and_major,
                                 self.iter_term_and_course,
                                 self.iter_teaching_class, self.sync_students)
        logger.info('Crawl start!'.center(72, '='))
        self.job_manager.start(dfs_mode)
        logger.info(
            'Jobs are all dispatched. Waiting for database requests handling.')
        self.db_manager.join()
        logger.info('Crawl finished!'.center(72, '='))
        report_db(self.db_manager.db)

    # The crawl jobs follow
    def iter_term_and_major(self):
        # @structure {'专业': [{'专业代码': str, '专业名称': str}], '学期': [{'学期代码': str, '学期名称': str}]}
        code = self.shortcut.get_code()
        terms = code['学期']
        majors = code['专业']

        for term in terms:
            term_code = term['学期代码']
            self.db_manager.request(
                'term', InsertOnNotExist({'学期代码': term_code}, term))
            yield term_code, None

        max_term_number = int(terms[-1]['学期代码'])
        # some majors have been deleted, so many of their records are gone
        for major in majors:
            major_code = major['专业代码']
            self.db_manager.request(
                'major', InsertOnNotExist({'专业代码': major_code}, major))
            for i in term_range(major['专业名称'], max_term_number):
                term_code = '%03d' % i
                yield term_code, major_code

    def iter_term_and_course(self, term_code, major_code=None):
        if major_code:
            courses = self.shortcut.get_teaching_plan(xqdm=term_code,
                                                      zydm=major_code)
            for course in courses:
                course_code = course['课程代码']
                self.db_manager.request(
                    'course', InsertOnNotExist({'课程代码': course_code}, course))

                plan_doc = {
                    '课程代码': course_code,
                    '学期代码': term_code,
                    '专业代码': major_code
                }
                self.db_manager.request('plan',
                                        InsertOnNotExist(plan_doc, plan_doc))

                yield term_code, course_code
        else:
            courses = self.shortcut.get_teaching_plan(xqdm=term_code, kclx='x')
            for course in courses:
                course_code = course['课程代码']
                self.db_manager.request(
                    'course', InsertOnNotExist({'课程代码': course_code}, course))
                yield term_code, course_code

    def iter_teaching_class(self,
                            term_code,
                            course_code=None,
                            course_name=None):
        if course_code is None:
            is_new = True
        else:
            key = term_code + course_code
            self.term_course_mutex.acquire()
            is_new = key not in self.term_course_flags
            if is_new:
                self.term_course_flags.add(key)
            # releasing the lock inside the if-block would leave it held
            # whenever a duplicate key shows up
            self.term_course_mutex.release()

        if is_new:
            # @structure [{'任课教师': str, '课程名称': str, '教学班号': str, 'c': str, '班级容量': int}]
            classes = self.shortcut.search_course(xqdm=term_code,
                                                  kcdm=course_code,
                                                  kcmc=course_name)
            for teaching_class in classes:
                course_code = teaching_class['课程代码']
                class_code = teaching_class['教学班号']
                # @structure {'校区': str,'开课单位': str,'考核类型': str,'课程类型': str,'课程名称': str,'教学班号': str,
                # '起止周': str, '时间地点': str,'学分': float,'性别限制': str,'优选范围': str,'禁选范围': str,'选中人数': int,'备 注': str}
                class_info = self.db_manager.db['class'].find_one({
                    '学期代码': term_code,
                    '课程代码': course_code,
                    '教学班号': class_code
                })
                if not class_info:
                    class_info = self.shortcut.get_class_info(xqdm=term_code,
                                                              kcdm=course_code,
                                                              jxbh=class_code)
                    class_info.update(teaching_class)
                    # the interface takes no term-code parameter, so set it here
                    class_info['学期代码'] = term_code

                    self.db_manager.request('class', InsertOne(class_info))
                yield term_code, course_code, class_code

    def sync_students(self, term_code, course_code, class_code):
        # @structure {'学期': str, '班级名称': str, '学生': [{'姓名': str, '学号': int}]}
        students = self.shortcut.get_class_students(xqdm=term_code,
                                                    kcdm=course_code,
                                                    jxbh=class_code)
        # there may be no result
        if students:
            students = students['学生']
            for student in students:
                student_code = student['学号']
                student_name = student['姓名']
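                # a trailing '*' in the name marks a female student; record
                # the gender and strip the marker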
                student['性别'] = '女' if student_name.endswith('*') else '男'
                student['姓名'] = student_name.rstrip('*')

                self.db_manager.request(
                    'student', InsertOnNotExist({'学号': student_code}, student))

                class_student_doc = {
                    '学期代码': term_code,
                    '课程代码': course_code,
                    '教学班号': class_code,
                    '学号': student_code
                }
                self.db_manager.request(
                    'class_student',
                    InsertOnNotExist(class_student_doc, class_student_doc))
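
A minimal, runnable sketch of the duplicate-suppression pattern used in iter_teaching_class above: a gevent Semaphore guards a shared set so that, among concurrent greenlets, exactly one claims each key. The KeyClaimer name is illustrative, not part of the original crawler.

from gevent.lock import Semaphore

class KeyClaimer(object):
    """Greenlet-safe "first caller wins" registry."""

    def __init__(self):
        self._mutex = Semaphore()
        self._seen = set()

    def claim(self, key):
        # acquire before testing so two greenlets cannot both see the key
        # as new; release outside the if so a duplicate key never leaves
        # the lock held (the same pitfall noted in iter_teaching_class)
        self._mutex.acquire()
        try:
            if key in self._seen:
                return False
            self._seen.add(key)
            return True
        finally:
            self._mutex.release()

# claimer = KeyClaimer()
# if claimer.claim(term_code + course_code):
#     ...  # only the first greenlet for this key does the work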
Beispiel #45
0
class TransportChannel(object):
    def __init__(self, uplink=None, downlink=None):
        self.log = logging.getLogger("{module}.{name}".format(
            module=self.__class__.__module__, name=self.__class__.__name__))
        
        self.downlink = downlink
        self.uplink = uplink
        self.recv_callback = None

        self.context = zmq.Context()
        self.poller = zmq.Poller()

        self.ul_socket = self.context.socket(zmq.SUB) # one SUB socket for uplink communication over topics
        if sys.version_info.major >= 3:
            self.ul_socket.setsockopt_string(zmq.SUBSCRIBE,  "NEW_NODE")
            self.ul_socket.setsockopt_string(zmq.SUBSCRIBE,  "NODE_EXIT")
        else:
            self.ul_socket.setsockopt(zmq.SUBSCRIBE,  "NEW_NODE")
            self.ul_socket.setsockopt(zmq.SUBSCRIBE,  "NODE_EXIT")

        self.downlinkSocketLock = Semaphore(value=1)
        self.dl_socket = self.context.socket(zmq.PUB) # one PUB socket for downlink communication over topics

        # register UL socket in poller
        self.poller.register(self.ul_socket, zmq.POLLIN)

        self.importedPbClasses = {}


    def set_downlink(self, downlink):
        self.log.debug("Set Downlink: {}".format(downlink))
        self.downlink = downlink


    def set_uplink(self, uplink):
        self.log.debug("Set Uplink: {}".format(uplink))
        self.uplink = uplink


    def subscribe_to(self, topic):
        self.log.debug("Controller subscribes to topic: {}".format(topic))
        if sys.version_info.major >= 3:
            self.ul_socket.setsockopt_string(zmq.SUBSCRIBE, topic)
        else:
            self.ul_socket.setsockopt(zmq.SUBSCRIBE, topic)
 

    def set_recv_callback(self, callback):
        self.recv_callback = callback


    def send_downlink_msg(self, msgContainer):
        msgContainer[0] = msgContainer[0].encode('utf-8')
        cmdDesc = msgContainer[1]
        msg = msgContainer[2]

        if cmdDesc.serialization_type == msgs.CmdDesc.PICKLE:
            try:
                msg = pickle.dumps(msg)
            except Exception:
                # fall back to dill for objects pickle cannot handle
                msg = dill.dumps(msg)
        elif cmdDesc.serialization_type == msgs.CmdDesc.PROTOBUF:
            msg = msg.SerializeToString()

        msgContainer[1] = cmdDesc.SerializeToString()
        msgContainer[2] = msg

        # ZeroMQ sockets are not safe for concurrent use; the semaphore
        # serializes sends on the shared PUB socket
        self.downlinkSocketLock.acquire()
        try:
            self.dl_socket.send_multipart(msgContainer)
        finally:
            self.downlinkSocketLock.release()

    def deserialize_protobuff(self, message, typ):
        class_ = self.importedPbClasses.get(typ, None)

        if class_ is None:
            module_, class_ = typ.rsplit('.', 1)
            class_ = getattr(import_module(module_), class_)
            self.importedPbClasses[typ] = class_

        rv = class_()
        rv.ParseFromString(message)
        return rv

    def start_receiving(self):
        socks = dict(self.poller.poll())
        if self.ul_socket in socks and socks[self.ul_socket] == zmq.POLLIN:
            try:
                msgContainer = self.ul_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError:
                # re-raise the original error instead of raising the bare class
                raise

            assert len(msgContainer) == 3, msgContainer

            dest = msgContainer[0]
            cmdDesc = msgs.CmdDesc()
            cmdDesc.ParseFromString(msgContainer[1])
            msg = msgContainer[2]
            if cmdDesc.serialization_type == msgs.CmdDesc.PICKLE:
                try:
                    msg = pickle.loads(msg)
                except Exception:
                    msg = dill.loads(msg)
            elif cmdDesc.serialization_type == msgs.CmdDesc.PROTOBUF:
                if cmdDesc.pb_full_name:
                    msg = self.deserialize_protobuff(msg, cmdDesc.pb_full_name)

            msgContainer[0] = dest.decode('utf-8')
            msgContainer[1] = cmdDesc
            msgContainer[2] = msg
            self.recv_callback(msgContainer)


    def start(self):
        self.log.debug("Controller on DL-{}, UP-{}".format(self.downlink, self.uplink))
        self.dl_socket.bind(self.downlink)
        self.ul_socket.bind(self.uplink)


    def stop(self):
        self.ul_socket.setsockopt(zmq.LINGER, 0)
        self.dl_socket.setsockopt(zmq.LINGER, 0)
        self.ul_socket.close()
        self.dl_socket.close()
        self.context.term() 
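
A hedged usage sketch of the channel above; the tcp endpoints and node topic are placeholder assumptions, while msgs.CmdDesc and its serialization_type field are the same descriptor the class itself uses.

def on_msg(msgContainer):
    dest, cmdDesc, msg = msgContainer
    logging.info("%s -> %s", dest, msg)

channel = TransportChannel(uplink="tcp://*:5567", downlink="tcp://*:5566")
channel.set_recv_callback(on_msg)
channel.start()

# Downlink messages are three-part containers [topic, CmdDesc, payload];
# send_downlink_msg serializes them according to cmdDesc.serialization_type:
# cmdDesc = msgs.CmdDesc()
# cmdDesc.serialization_type = msgs.CmdDesc.PICKLE
# channel.send_downlink_msg(["node_1", cmdDesc, {"cmd": "hello"}])

while True:
    channel.start_receiving()  # poll once, dispatch to the callback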
Beispiel #46
0
class Stream(object):
    """
    参考tornado的iostream
    """

    read_lock = None
    write_lock = None

    reading = False
    writing = False

    def __init__(self, sock, max_buffer_size=-1,
                 read_chunk_size=None, use_gevent=False, lock_mode=LOCK_MODE_RDWR):
        self.sock = sock
        self.max_buffer_size = max_buffer_size
        self.read_chunk_size = read_chunk_size or READ_CHUNK_SIZE

        self._read_buffer = collections.deque()
        self._read_buffer_size = 0
        self._read_delimiter = None
        self._read_regex = None
        self._read_bytes = None
        self._read_until_close = False
        self._read_checker = None

        if use_gevent:
            from gevent.lock import Semaphore as Lock
        else:
            from threading import Lock

        if lock_mode & LOCK_MODE_READ:
            self.read_lock = Lock()

        if lock_mode & LOCK_MODE_WRITE:
            self.write_lock = Lock()

    def close(self, exc_info=False):
        if self.closed():
            # already closed, just return
            return

        self.close_fd()

    def shutdown(self, how=2):
        """
        gevent的close只是把sock替换为另一个类的实例。
        这个实例的任何方法都会报错,但只有当真正调用recv、write或者有recv or send事件的时候,才会调用到这些函数,才可能检测到。
        而我们在endpoint对应的函数里spawn_later一个新greenlet而不做join的话,connection的while循环此时已经开始read了。

        之所以不把这个函数实现到connection,是因为shutdown更类似于触发一个close的事件
        用shutdown可以直接触发. how: 0: SHUT_RD, 1: SHUT_WR, else: all
        shutdown后还是会触发close事件以及相关的回调函数,不必担心
        """
        if self.closed():
            return

        self.shutdown_fd(how)

    @lock_read
    def read_until_regex(self, regex):
        """Run when we read the given regex pattern.
        """
        self._read_regex = re.compile(regex)
        while 1:
            ret, data = self._try_inline_read()
            if ret <= 0:
                return data

    @lock_read
    def read_until(self, delimiter):
        """
        为了兼容调用的方法
        """

        self._read_delimiter = delimiter
        while 1:
            ret, data = self._try_inline_read()
            if ret <= 0:
                return data

    @lock_read
    def read_bytes(self, num_bytes):
        """Run when we read the given number of bytes.
        """
        assert isinstance(num_bytes, numbers.Integral)
        self._read_bytes = num_bytes
        while 1:
            ret, data = self._try_inline_read()
            if ret <= 0:
                return data

    @lock_read
    def read_until_close(self):
        """Reads all data from the socket until it is closed.
        """
        if self.closed():
            return self._consume(self._read_buffer_size)

        self._read_until_close = True
        while 1:
            ret, data = self._try_inline_read()
            if ret <= 0:
                return data

    @lock_read
    def read_with_checker(self, checker):
        """
        checker(buf):
            0 继续接收
            >0 使用的长度
            <0 异常
        """

        self._read_checker = checker
        while 1:
            ret, data = self._try_inline_read()
            if ret <= 0:
                return data

    @lock_write
    def write(self, data):
        """
        写数据
        """

        if self.closed():
            return False

        while data:
            num_bytes = self.write_to_fd(data)

            if num_bytes is None:
                return False

            data = data[num_bytes:]

        return True

    def closed(self):
        return not self.sock

    def acquire_read_lock(self):
        self.reading = True
        if self.read_lock:
            self.read_lock.acquire()

    def release_read_lock(self):
        self.reading = False
        if self.read_lock:
            self.read_lock.release()

    def acquire_write_lock(self):
        self.writing = True
        if self.write_lock:
            self.write_lock.acquire()

    def release_write_lock(self):
        self.writing = False
        if self.write_lock:
            self.write_lock.release()

    def read_from_fd(self):
        """
        从fd里读取数据。
        超时不捕获异常,由外面捕获。
        中断也不应该关闭链接,比如kill -USR1 会抛出中断异常
        其他错误则直接关闭连接
        :return:
        """
        try:
            chunk = self.sock.recv(self.read_chunk_size)
        except socket.timeout as e:
            # the server side never times out on recv
            raise e
        except socket.error as e:
            if e.errno == errno.EINTR:
                # interrupted: return an empty string but keep the connection
                return ''
            else:
                # Why "Connection reset by peer" happens: reportedly the peer
                # closed the connection abnormally, e.g. the peer process
                # exited unexpectedly. Reproduced by: C sends data to S; S
                # replies but C never reads the reply, then C calls close or
                # is destructed
                logger.error('exc occur.', exc_info=True)
                self.close()
                return None
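
The @lock_read and @lock_write decorators used by Stream are not shown in this example; a plausible sketch, consistent with the acquire_*/release_* helpers above, wraps each method in try/finally so the lock is always released:

import functools

def lock_read(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.acquire_read_lock()
        try:
            return func(self, *args, **kwargs)
        finally:
            self.release_read_lock()
    return wrapper

def lock_write(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.acquire_write_lock()
        try:
            return func(self, *args, **kwargs)
        finally:
            self.release_write_lock()
    return wrapper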
Beispiel #47
0
class Match(Actor):

    '''**Pattern matching on a key/value document stream.**

    This module routes messages to a queue associated to the matching rule
    set.  The event['data'] payload has to be of <type 'dict'>.  Typically,
    the source data is JSON converted to a Python dictionary.

    The match rules can be either stored on disk or directly defined into the
    bootstrap file.

    A match rule is written in YAML syntax and consists out of 2 parts:

    - condition:

        A list of dictionaries holding the individual conditions, which
        ALL have to match for the complete rule to match.

    ::

        re:     Regex matching
        !re:    Negative regex matching
        >:      Bigger than
        >=:     Bigger than or equal to
        <:      Smaller than
        <=:     Smaller than or equal to
        =:      Equal to
        in:     Evaluate list membership
        !in:    Evaluate negative list membership


    - queue:

        The queue section contains a list of dictionaries/maps each containing
        1 key with another dictionary/map as a value.  These key/value pairs
        are added to the *header section* of the event and stored under the
        queue name key.
        If you are not interested in adding any information to the header you
        can leave the dictionary empty, as the first example below shows.

    All rules are evaluated sequentially in no particular order.  When a
    rule matches, evaluation of the remaining rules continues until all
    rules are processed.

    Examples
    ~~~~~~~~

    This example would route the events - with field "greeting" containing
    the value "hello" - to the outbox queue without adding any information
    to the header of the event itself.

    ::

        condition:
            - greeting: re:^hello$

        queue:
            - outbox:



    This example combines multiple conditions and stores 4 variables under
    event["header"][self.name] while submitting the event to the modules'
    **email** queue.

    ::

        condition:
            - check_command: re:check:host.alive
            - hostproblemid: re:\d*
            - hostgroupnames: in:tag:development

        queue:
            - email:
                from: [email protected]
                to:
                    - [email protected]
                subject: UMI - Host  {{ hostname }} is  {{ hoststate }}.
                template: host_email_alert



    Parameters:

        - name(str)
           |  The name of the module.

        - size(int)
           |  The default max length of each queue.

        - frequency(int)
           |  The frequency in seconds to generate metrics.

        - location(str)("")
           |  The directory containing rules.
           |  If empty, no rules are read from disk.

        - rules(dict)({})
           |  A dict of rules in the above described format.
           |  For example:
           |  {"omg": {"condition": [{"greeting": "re:^hello$"}], "queue": [{"outbox": {"one": 1}}]}}


    Queues:

        - inbox
           |  Incoming events

        - <queue_name>
           |  The queue which matches a rule

        - nomatch
           |  The queue receiving event without matches

    '''

    def __init__(self, actor_config, location="", rules={}):
        Actor.__init__(self, actor_config)

        self.pool.createQueue("inbox")
        self.pool.createQueue("nomatch")
        self.registerConsumer(self.consume, "inbox")

        self.__active_rules = {}
        self.match = MatchRules()
        self.rule_lock = Semaphore()

    def preHook(self):

        if self.kwargs.location != "":
            self.createDir()
            self.logging.info("Rules directoy '%s' defined." % (self.kwargs.location))
            self.sendToBackground(self.monitorRuleDirectory)
        else:
            self.__active_rules.update(self.uplook.dump()["rules"])
            self.logging.info("No rules directory defined, not reading rules from disk.")

    def createDir(self):

        if os.path.exists(self.kwargs.location):
            if not os.path.isdir(self.kwargs.location):
                raise Exception("%s exists but is not a directory" % (self.kwargs.location))
            else:
                self.logging.info("Directory %s exists so I'm using it." % (self.kwargs.location))
        else:
            self.logging.info("Directory %s does not exist so I'm creating it." % (self.kwargs.location))
            os.makedirs(self.kwargs.location)

    def monitorRuleDirectory(self):

        '''
        Loads new rules when changes happen.
        '''

        self.logging.info("Monitoring directory '%s' for changes" % (self.kwargs.location))
        self.readRules = ReadRulesDisk(self.logging, self.kwargs.location)

        new_rules = self.readRules.readDirectory()
        new_rules.update(self.kwargs.rules)
        self.__active_rules = new_rules
        self.logging.info("Read %s rules from disk and %s defined in config." % (len(new_rules), len(self.kwargs.rules)))
        while self.loop():
            try:
                new_rules = self.readRules.waitForChanges()
                new_rules.update(self.kwargs.rules)
                if cmp(self.__active_rules, new_rules) != 0:
                    self.rule_lock.acquire()
                    self.__active_rules = new_rules
                    self.rule_lock.release()
                    self.logging.info("Read %s rules from disk and %s defined in config." % (len(new_rules), len(self.kwargs.rules)))
            except Exception as err:
                self.logging.warning("Problem reading rules directory.  Reason: %s" % (err))
                sleep(1)

    def consume(self, event):
        '''Submits matching documents to the defined queue along with
        the defined header.'''

        if isinstance(event.data, dict):
            self.rule_lock.acquire()
            for rule in self.__active_rules:
                e = deepcopy(event)
                if self.evaluateCondition(self.__active_rules[rule]["condition"], e.data):
                    e.setHeaderValue("rule", rule)
                    for queue in self.__active_rules[rule]["queue"]:
                        event_copy = deepcopy(e)
                        for name in queue:
                            if queue[name] is not None:
                                for key, value in queue[name].iteritems():
                                    event_copy.setHeaderValue(key, value)
                                # event_copy["header"][self.name].update(queue[name])
                            self.submit(event_copy, self.pool.getQueue(name))
                else:
                    e.setHeaderValue("rule", rule)
                    self.submit(e, self.pool.queue.nomatch)
                    # self.logging.debug("No match for rule %s." % (rule))
            self.rule_lock.release()
        else:
            raise Exception("Incoming data is not of type dict, dropped.")

    def evaluateCondition(self, conditions, fields):
        for condition in conditions:
            for field in condition:
                if field in fields:
                    if not self.match.do(condition[field], fields[field]):
                        # self.logging.debug("field %s with condition %s DOES NOT MATCH value %s" % (field, condition[field], fields[field]))
                        return False
                    else:
                        pass
                        # self.logging.debug("field %s with condition %s MATCHES value %s" % (field, condition[field], fields[field]))
                else:
                    return False
        return True
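
A hedged sketch of a rule passed inline through the rules parameter, following the structure documented in the docstring above; the rule name and header values are illustrative:

hello_rules = {
    "hello_rule": {
        "condition": [
            {"greeting": "re:^hello$"},
        ],
        "queue": [
            {"outbox": {"matched_by": "hello_rule"}},
        ],
    },
}

# Match(actor_config, rules=hello_rules) would route an event whose data is
# {"greeting": "hello"} to the "outbox" queue with "matched_by" added to its
# header, and everything else to the "nomatch" queue.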
Beispiel #48
0
class ConnectionInstance(object):
    def __init__(self, parent, io_loop=None):
        self._parent = parent
        self._closing = False
        self._user_queries = { }
        self._cursor_cache = { }

        self._write_mutex = Semaphore()
        self._socket = None

    def connect(self, timeout):
        with gevent.Timeout(timeout, RqlTimeoutError(self._parent.host, self._parent.port)):
            self._socket = SocketWrapper(self)

        # Start a parallel coroutine to perform reads
        gevent.spawn(self._reader)
        return self._parent

    def is_open(self):
        return self._socket is not None and self._socket.is_open()

    def close(self, noreply_wait, token, exception=None):
        self._closing = True
        if exception is not None:
            err_message = "Connection is closed (%s)." % str(exception)
        else:
            err_message = "Connection is closed."

        # Cursors may remove themselves when errored, so copy a list of them
        for cursor in list(self._cursor_cache.values()):
            cursor._error(err_message)

        for query, async_res in self._user_queries.values():
            async_res.set_exception(RqlDriverError(err_message))

        self._user_queries = { }
        self._cursor_cache = { }

        if noreply_wait:
            noreply = net.Query(pQuery.NOREPLY_WAIT, token, None, None)
            self.run_query(noreply, False)

        try:
            self._socket.close()
        except Exception:
            pass

    # TODO: make connection recoverable if interrupted by a user's gevent.Timeout?
    def run_query(self, query, noreply):
        self._write_mutex.acquire()

        try:
            self._socket.sendall(query.serialize(self._parent._get_json_encoder(query)))
        finally:
            self._write_mutex.release()

        if noreply:
            return None

        async_res = AsyncResult()
        self._user_queries[query.token] = (query, async_res)
        return async_res.get()

    # The _reader coroutine runs in its own coroutine in parallel, reading responses
    # off of the socket and forwarding them to the appropriate AsyncResult or Cursor.
    # This is shut down as a consequence of closing the stream, or an error in the
    # socket/protocol from the server.  Unexpected errors in this coroutine will
    # close the ConnectionInstance and be passed to any open AsyncResult or Cursors.
    def _reader(self):
        try:
            while True:
                buf = self._socket.recvall(12)
                (token, length,) = struct.unpack("<qL", buf)
                buf = self._socket.recvall(length)

                cursor = self._cursor_cache.get(token)
                if cursor is not None:
                    cursor._extend(buf)
                elif token in self._user_queries:
                    # Do not pop the query from the dict until later, so
                    # we don't lose track of it in case of an exception
                    query, async_res = self._user_queries[token]
                    res = net.Response(token, buf, self._parent._get_json_decoder(query))
                    if res.type == pResponse.SUCCESS_ATOM:
                        async_res.set(net.maybe_profile(res.data[0], res))
                    elif res.type in (pResponse.SUCCESS_SEQUENCE,
                                      pResponse.SUCCESS_PARTIAL):
                        cursor = GeventCursor(self, query, res)
                        async_res.set(net.maybe_profile(cursor, res))
                    elif res.type == pResponse.WAIT_COMPLETE:
                        async_res.set(None)
                    else:
                        async_res.set_exception(res.make_error(query))
                    del self._user_queries[token]
                elif not self._closing:
                    raise RqlDriverError("Unexpected response received.")
        except Exception as ex:
            if not self._closing:
                self.close(False, None, ex)
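
The _reader loop above decodes a fixed 12-byte response header (a little-endian int64 query token followed by a little-endian uint32 payload length) before reading the payload itself. A standalone sketch of that framing:

import struct

def frame_response(token, payload):
    # 8-byte token + 4-byte payload length, then the payload itself
    return struct.pack("<qL", token, len(payload)) + payload

def parse_header(header):
    assert len(header) == 12
    return struct.unpack("<qL", header)  # (token, length)

framed = frame_response(42, b'{"t": 1}')
assert parse_header(framed[:12]) == (42, 8)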
Beispiel #49
0
class PostgresConnectionPool(object):
    """A pool of psycopg2 connections shared between multiple greenlets."""

    @classmethod
    def for_url(cls, url, initial=1, limit=20):
        """Construct a connection pool instance given a URL."""
        from . import ConnectionFailed
        params = parse_database_url(url)
        try:
            db = cls(initial=initial, limit=limit, **params)
        except ConnectionFailed as e:
            raise ConnectionFailed(
                'Failed to connect to %s' % make_safe_url(params), *e.args
            )
        db.url = url  # record the URL for debugging
        return db

    @classmethod
    def for_name(cls, name, initial=1, limit=20):
        """Construct a connection pool instance from the named setting."""
        from ..config import settings
        url = getattr(settings, name)
        return cls.for_url(url, initial=initial, limit=limit)

    def __init__(self, initial=1, limit=20, **settings):
        """Construct a pool of connections.

        settings are passed straight to psycopg2.connect. initial connections
        are opened immediately. More connections may be opened if they are
        required, but at most limit connections may be open at any one time.

        """
        self.settings = settings
        self.sem = Semaphore(limit)
        self.size = 0
        self.pool = []

        for i in xrange(initial):
            self.pool.append(self._connect())

    def _connect(self):
        """Connect to PostgreSQL using the stored connection settings."""
        from . import ConnectionFailed, OperationalError
        try:
            pg = psycopg2.connect(**self.settings)
        except OperationalError as e:
            url = make_safe_url(self.settings)
            raise ConnectionFailed(
                'Failed to connect using %s' % url, *e.args
            )
        self.size += 1
        print "PostgreSQL connection pool size:", self.size
        return pg

    @contextmanager
    def cursor(self):
        """Obtain a cursor from the pool. We reserve the connection exclusively
        for this cursor, as cursors are apparently not safe to be used in multiple
        greenlets.
        """
        with self.connection() as conn:
            yield conn.cursor()

    @contextmanager
    def connection(self):
        """Obtain a connection from the pool, as a context manager, so that the
        connection will eventually be returned to the pool.

        >>> pool = PostgresConnectionPool(**settings)
        >>> with pool.connection() as conn:
        ...     c = conn.cursor()
        ...     c.execute(...)
        ...     conn.commit()
        """
        self.sem.acquire()
        try:
            conn = self.pool.pop(0)
        except IndexError:
            conn = self._connect()

        try:
            yield conn
        except psycopg2.OperationalError:
            # Connection errors should result in the connection being removed
            # from the pool.
            #
            # Unfortunately OperationalError could possibly mean other things and
            # we don't know enough to determine which
            try:
                conn.close()
            finally:
                self.size -= 1
                conn = None
                raise
        except:
            conn.rollback()
            raise
        else:
            conn.commit()
        finally:
            if conn is not None:
                conn.reset()
                self.pool.append(conn)
            self.sem.release()
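
A hedged usage sketch (host and dbname are placeholders): the Semaphore caps checkouts at limit, so at most that many greenlets hold a connection at once while the rest block in sem.acquire():

import gevent

pool = PostgresConnectionPool(initial=1, limit=4,
                              host="localhost", dbname="example")

def worker(n):
    # cursor() checks a connection out of the pool; commit or rollback
    # happens automatically when the block exits
    with pool.cursor() as c:
        c.execute("SELECT %s", (n,))

gevent.joinall([gevent.spawn(worker, n) for n in range(10)])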
Beispiel #50
0
class Rotkehlchen(object):
    def __init__(self, args):
        self.lock = Semaphore()
        self.lock.acquire()
        self.results_cache: typing.ResultCache = dict()
        self.connected_exchanges = []

        logfilename = None
        if args.logtarget == 'file':
            logfilename = args.logfile

        if args.loglevel == 'debug':
            loglevel = logging.DEBUG
        elif args.loglevel == 'info':
            loglevel = logging.INFO
        elif args.loglevel == 'warn':
            loglevel = logging.WARN
        elif args.loglevel == 'error':
            loglevel = logging.ERROR
        elif args.loglevel == 'critical':
            loglevel = logging.CRITICAL
        else:
            raise ValueError('Should never get here. Illegal log value')

        logging.basicConfig(
            filename=logfilename,
            filemode='w',
            level=loglevel,
            format='%(asctime)s -- %(levelname)s:%(name)s:%(message)s',
            datefmt='%d/%m/%Y %H:%M:%S %Z',
        )

        if not args.logfromothermodules:
            logging.getLogger('zerorpc').setLevel(logging.CRITICAL)
            logging.getLogger('zerorpc.channel').setLevel(logging.CRITICAL)
            logging.getLogger('urllib3').setLevel(logging.CRITICAL)
            logging.getLogger('urllib3.connectionpool').setLevel(
                logging.CRITICAL)

        self.sleep_secs = args.sleep_secs
        self.data_dir = args.data_dir
        self.args = args
        self.last_data_upload_ts = 0

        self.poloniex = None
        self.kraken = None
        self.bittrex = None
        self.bitmex = None
        self.binance = None

        self.data = DataHandler(self.data_dir)

        self.lock.release()
        self.shutdown_event = gevent.event.Event()

    def initialize_exchanges(self, secret_data):
        # initialize exchanges for which we have keys and are not already initialized
        if self.kraken is None and 'kraken' in secret_data:
            self.kraken = Kraken(
                str.encode(secret_data['kraken']['api_key']),
                str.encode(secret_data['kraken']['api_secret']), self.data_dir)
            self.connected_exchanges.append('kraken')
            self.trades_historian.set_exchange('kraken', self.kraken)

        if self.poloniex is None and 'poloniex' in secret_data:
            self.poloniex = Poloniex(
                str.encode(secret_data['poloniex']['api_key']),
                str.encode(secret_data['poloniex']['api_secret']),
                self.inquirer, self.data_dir)
            self.connected_exchanges.append('poloniex')
            self.trades_historian.set_exchange('poloniex', self.poloniex)

        if self.bittrex is None and 'bittrex' in secret_data:
            self.bittrex = Bittrex(
                str.encode(secret_data['bittrex']['api_key']),
                str.encode(secret_data['bittrex']['api_secret']),
                self.inquirer, self.data_dir)
            self.connected_exchanges.append('bittrex')
            self.trades_historian.set_exchange('bittrex', self.bittrex)

        if self.binance is None and 'binance' in secret_data:
            self.binance = Binance(
                str.encode(secret_data['binance']['api_key']),
                str.encode(secret_data['binance']['api_secret']),
                self.inquirer, self.data_dir)
            self.connected_exchanges.append('binance')
            self.trades_historian.set_exchange('binance', self.binance)

        if self.bitmex is None and 'bitmex' in secret_data:
            self.bitmex = Bitmex(
                str.encode(secret_data['bitmex']['api_key']),
                str.encode(secret_data['bitmex']['api_secret']), self.inquirer,
                self.data_dir)
            self.connected_exchanges.append('bitmex')
            self.trades_historian.set_exchange('bitmex', self.bitmex)

    def try_premium_at_start(self, api_key, api_secret, create_new,
                             sync_approval, user_dir):
        """Check if new user provided api pair or we already got one in the DB"""

        if api_key != '':
            self.premium, valid, empty_or_error = premium_create_and_verify(
                api_key, api_secret)
            if not valid:
                log.error('Given API key is invalid')
                # At this point we are at a new user trying to create an account with
                # premium API keys and we failed. But a directory was created. Remove it.
                shutil.rmtree(user_dir)
                raise AuthenticationError(
                    'Could not verify keys for the new account. '
                    '{}'.format(empty_or_error))
        else:
            # If we got premium initialize it and try to sync with the server
            premium_credentials = self.data.db.get_rotkehlchen_premium()
            if premium_credentials:
                api_key = premium_credentials[0]
                api_secret = premium_credentials[1]
                self.premium, valid, empty_or_error = premium_create_and_verify(
                    api_key, api_secret)
                if not valid:
                    log.error(
                        'The API keys found in the Database are not valid. Perhaps '
                        'they expired?')
                    # invalid stored credentials: drop premium and bail out
                    del self.premium
                    return
            else:
                # no premium credentials in the DB
                return

        if self.can_sync_data_from_server():
            if sync_approval == 'unknown' and not create_new:
                log.info('DB data at server newer than local')
                raise PermissionError(
                    'Rotkehlchen Server has newer version of your DB data. '
                    'Should we replace local data with the server\'s?')
            elif sync_approval == 'yes' or (sync_approval == 'unknown' and create_new):
                log.info('User approved data sync from server')
                if self.sync_data_from_server():
                    if create_new:
                        # if we successfully synced data from the server and this is
                        # a new account, make sure the api keys are properly stored
                        # in the DB
                        self.data.db.set_rotkehlchen_premium(
                            api_key, api_secret)
            else:
                log.debug('Could sync data from server but user refused')

    def unlock_user(self, user, password, create_new, sync_approval, api_key,
                    api_secret):
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
        )
        # unlock or create the DB
        self.password = password
        user_dir = self.data.unlock(user, password, create_new)
        self.try_premium_at_start(api_key, api_secret, create_new,
                                  sync_approval, user_dir)

        secret_data = self.data.db.get_exchange_secrets()
        settings = self.data.db.get_settings()
        historical_data_start = settings['historical_data_start']
        eth_rpc_port = settings['eth_rpc_port']
        self.trades_historian = TradesHistorian(
            self.data_dir,
            self.data.db,
            self.data.get_eth_accounts(),
            historical_data_start,
        )
        self.price_historian = PriceHistorian(
            self.data_dir,
            historical_data_start,
        )
        db_settings = self.data.db.get_settings()
        self.accountant = Accountant(
            price_historian=self.price_historian,
            profit_currency=self.data.main_currency(),
            user_directory=user_dir,
            create_csv=True,
            ignored_assets=self.data.db.get_ignored_assets(),
            include_crypto2crypto=db_settings['include_crypto2crypto'],
            taxfree_after_period=db_settings['taxfree_after_period'],
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=db_settings['anonymized_logs'])

        self.inquirer = Inquirer(kraken=self.kraken)
        self.initialize_exchanges(secret_data)

        ethchain = Ethchain(eth_rpc_port)
        self.blockchain = Blockchain(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            all_eth_tokens=self.data.eth_tokens,
            owned_eth_tokens=self.data.db.get_owned_tokens(),
            inquirer=self.inquirer,
            ethchain=ethchain,
        )

    def set_premium_credentials(self, api_key, api_secret):
        log.info('Setting new premium credentials')
        if hasattr(self, 'premium'):
            valid, empty_or_error = self.premium.set_credentials(
                api_key, api_secret)
        else:
            self.premium, valid, empty_or_error = premium_create_and_verify(
                api_key, api_secret)

        if valid:
            self.data.set_premium_credentials(api_key, api_secret)
            return True, ''
        log.error('Setting new premium credentials failed',
                  error=empty_or_error)
        return False, empty_or_error

    def maybe_upload_data_to_server(self):
        # upload only if unlocked user has premium
        if not hasattr(self, 'premium'):
            return

        # upload only once per hour
        diff = ts_now() - self.last_data_upload_ts
        if diff > 3600:
            self.upload_data_to_server()

    def upload_data_to_server(self):
        log.debug('upload to server -- start')
        data, our_hash = self.data.compress_and_encrypt_db(self.password)
        success, result_or_error = self.premium.query_last_data_metadata()
        if not success:
            log.debug(
                'upload to server -- query last metadata failed',
                error=result_or_error,
            )
            return

        log.debug(
            'CAN_PUSH',
            ours=our_hash,
            theirs=result_or_error['data_hash'],
        )
        if our_hash == result_or_error['data_hash']:
            log.debug('upload to server -- same hash')
            # same hash -- no need to upload anything
            return

        our_last_write_ts = self.data.db.get_last_write_ts()
        if our_last_write_ts <= result_or_error['last_modify_ts']:
            # Server's DB was modified after our local DB
            log.debug("CAN_PUSH -> 3")
            log.debug('upload to server -- remote db more recent than local')
            return

        success, result_or_error = self.premium.upload_data(
            data, our_hash, our_last_write_ts, 'zlib')
        if not success:
            log.debug('upload to server -- upload error',
                      error=result_or_error)
            return

        self.last_data_upload_ts = ts_now()
        log.debug('upload to server -- success')

    def can_sync_data_from_server(self):
        log.debug('sync data from server -- start')
        data, our_hash = self.data.compress_and_encrypt_db(self.password)
        success, result_or_error = self.premium.query_last_data_metadata()
        if not success:
            log.debug('sync data from server failed', error=result_or_error)
            return False

        log.debug(
            'CAN_PULL',
            ours=our_hash,
            theirs=result_or_error['data_hash'],
        )
        if our_hash == result_or_error['data_hash']:
            log.debug('sync from server -- same hash')
            # same hash -- no need to get anything
            return False

        our_last_write_ts = self.data.db.get_last_write_ts()
        if our_last_write_ts >= result_or_error['last_modify_ts']:
            # Local DB is newer than Server DB
            log.debug('sync from server -- local DB more recent than remote')
            return False

        return True

    def sync_data_from_server(self):
        success, error_or_result = self.premium.pull_data()
        if not success:
            log.debug('sync from server -- pulling failed.',
                      error=error_or_result)
            return False

        self.data.decompress_and_decrypt_db(self.password,
                                            error_or_result['data'])
        return True

    def start(self):
        return gevent.spawn(self.main_loop)

    def main_loop(self):
        while not self.shutdown_event.is_set():
            log.debug('Main loop start')
            if self.poloniex is not None:
                self.poloniex.main_logic()
            if self.kraken is not None:
                self.kraken.main_logic()

            self.maybe_upload_data_to_server()

            log.debug('Main loop end')
            gevent.sleep(MAIN_LOOP_SECS_DELAY)

    def add_blockchain_account(self, blockchain, account):
        try:
            new_data = self.blockchain.add_blockchain_account(
                blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.add_blockchain_account(blockchain, account)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_blockchain_account(self, blockchain, account):
        try:
            new_data = self.blockchain.remove_blockchain_account(
                blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.remove_blockchain_account(blockchain, account)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def add_owned_eth_tokens(self, tokens):
        try:
            new_data = self.blockchain.track_new_tokens(tokens)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))

        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_owned_eth_tokens(self, tokens):
        try:
            new_data = self.blockchain.remove_eth_tokens(tokens)
        except InputError as e:
            return simple_result(False, str(e))
        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def process_history(self, start_ts, end_ts):
        (
            error_or_empty, history, margin_history, loan_history,
            asset_movements, eth_transactions
        ) = self.trades_historian.get_history(
            start_ts=0,  # entire-history processing needs the full history available
            end_ts=ts_now(),
            end_at_least_ts=end_ts)
        result = self.accountant.process_history(start_ts, end_ts, history,
                                                 margin_history, loan_history,
                                                 asset_movements,
                                                 eth_transactions)
        return result, error_or_empty

    def query_fiat_balances(self):
        log.info('query_fiat_balances called')
        result = {}
        balances = self.data.get_fiat_balances()
        for currency, amount in balances.items():
            amount = FVal(amount)
            usd_rate = query_fiat_pair(currency, 'USD')
            result[currency] = {
                'amount': amount,
                'usd_value': amount * usd_rate
            }

        return result

    def query_balances(self, requested_save_data=False):
        log.info('query_balances called',
                 requested_save_data=requested_save_data)

        balances = {}
        problem_free = True
        for exchange in self.connected_exchanges:
            exchange_balances, msg = getattr(self, exchange).query_balances()
            # If we got an error, disregard that exchange but make sure we don't save data
            if not exchange_balances:
                problem_free = False
            else:
                balances[exchange] = exchange_balances

        result, error_or_empty = self.blockchain.query_balances()
        if error_or_empty == '':
            balances['blockchain'] = result['totals']
        else:
            problem_free = False

        result = self.query_fiat_balances()
        if result != {}:
            balances['banks'] = result

        combined = combine_stat_dicts([v for k, v in balances.items()])
        total_usd_per_location = [(k, dict_get_sumof(v, 'usd_value'))
                                  for k, v in balances.items()]

        # calculate net usd value
        net_usd = FVal(0)
        for k, v in combined.items():
            net_usd += FVal(v['usd_value'])

        stats = {'location': {}, 'net_usd': net_usd}
        for entry in total_usd_per_location:
            name = entry[0]
            total = entry[1]
            if net_usd != FVal(0):
                percentage = (total / net_usd).to_percentage()
            else:
                percentage = '0%'
            stats['location'][name] = {
                'usd_value': total,
                'percentage_of_net_value': percentage,
            }

        for k, v in combined.items():
            if net_usd != FVal(0):
                percentage = (v['usd_value'] / net_usd).to_percentage()
            else:
                percentage = '0%'
            combined[k]['percentage_of_net_value'] = percentage

        result_dict = merge_dicts(combined, stats)

        allowed_to_save = requested_save_data or self.data.should_save_balances()
        if problem_free and allowed_to_save:
            self.data.save_balances_data(result_dict)

        # After adding it to the saved file we can overlay additional data that
        # is not required to be saved in the history file
        try:
            details = self.data.accountant.details
            for asset, (tax_free_amount, average_buy_value) in details.items():
                if asset not in result_dict:
                    continue

                result_dict[asset]['tax_free_amount'] = tax_free_amount
                result_dict[asset]['average_buy_value'] = average_buy_value

                current_price = (result_dict[asset]['usd_value'] /
                                 result_dict[asset]['amount'])
                if average_buy_value != FVal(0):
                    result_dict[asset]['percent_change'] = (
                        ((current_price - average_buy_value) /
                         average_buy_value) * 100)
                else:
                    result_dict[asset]['percent_change'] = 'INF'

        except AttributeError:
            pass

        return result_dict

    def set_main_currency(self, currency):
        with self.lock:
            self.data.set_main_currency(currency, self.accountant)
            if currency != 'USD':
                self.usd_to_main_currency_rate = query_fiat_pair(
                    'USD', currency)

    def set_settings(self, settings):
        log.info('Add new settings')

        message = ''
        with self.lock:
            if 'eth_rpc_port' in settings:
                result, msg = self.blockchain.set_eth_rpc_port(
                    settings['eth_rpc_port'])
                if not result:
                    # Don't save it in the DB
                    del settings['eth_rpc_port']
                    message += "\nEthereum RPC port not set: " + msg

            if 'main_currency' in settings:
                main_currency = settings['main_currency']
                if main_currency != 'USD':
                    self.usd_to_main_currency_rate = query_fiat_pair(
                        'USD', main_currency)

            res, msg = self.accountant.customize(settings)
            if not res:
                message += '\n' + msg
                return False, message

            _, msg = self.data.set_settings(settings, self.accountant)
            if msg != '':
                message += '\n' + msg

            # Always return success here but with a message
            return True, message

    def usd_to_main_currency(self, amount):
        main_currency = self.data.main_currency()
        if main_currency != 'USD' and not hasattr(self,
                                                  'usd_to_main_currency_rate'):
            self.usd_to_main_currency_rate = query_fiat_pair(
                'USD', main_currency)

        return self.usd_to_main_currency_rate * amount

    def setup_exchange(self, name, api_key, api_secret):
        log.info('setup_exchange', name=name)
        if name not in SUPPORTED_EXCHANGES:
            return False, 'Attempted to register unsupported exchange {}'.format(
                name)

        if getattr(self, name) is not None:
            return False, 'Exchange {} is already registered'.format(name)

        secret_data = {}
        secret_data[name] = {
            'api_key': api_key,
            'api_secret': api_secret,
        }
        self.initialize_exchanges(secret_data)

        exchange = getattr(self, name)
        result, message = exchange.validate_api_key()
        if not result:
            log.error(
                'Failed to validate API key for exchange',
                name=name,
                error=message,
            )
            self.delete_exchange_data(name)
            return False, message

        # Success, save the result in the DB
        self.data.db.add_exchange(name, api_key, api_secret)
        return True, ''

    def delete_exchange_data(self, name):
        self.connected_exchanges.remove(name)
        self.trades_historian.set_exchange(name, None)
        delattr(self, name)
        setattr(self, name, None)

    def remove_exchange(self, name):
        if getattr(self, name) is None:
            return False, 'Exchange {} is not registered'.format(name)

        self.delete_exchange_data(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        return True, ''

    def shutdown(self):
        log.info("Shutting Down")
        self.shutdown_event.set()
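
The upload and sync checks above reduce to one rule: do nothing when the hashes match, push only when the local DB was written after the server copy, and pull only when the server copy is newer. A compact restatement of that logic; the function name is illustrative:

def sync_direction(our_hash, their_hash, our_write_ts, their_modify_ts):
    if our_hash == their_hash:
        return None      # identical data, nothing to transfer
    if our_write_ts > their_modify_ts:
        return 'push'    # local DB modified after the server copy
    if our_write_ts < their_modify_ts:
        return 'pull'    # server copy modified after the local DB
    return None          # same timestamp but different hash: do neither

assert sync_direction('a', 'a', 10, 5) is None
assert sync_direction('a', 'b', 10, 5) == 'push'
assert sync_direction('a', 'b', 5, 10) == 'pull'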
Beispiel #51
0
class Queue(Greenlet):
    """Manages the queue of |Envelope| objects waiting for delivery. This is
    not a standard FIFO queue, a message's place in the queue depends entirely
    on the timestamp of its next delivery attempt.

    :param store: Object implementing :class:`QueueStorage`.
    :param relay: |Relay| object used to attempt message deliveries. If this
                  is not given, no deliveries will be attempted on received
                  messages.
    :param backoff: Function that, given an |Envelope| and number of delivery
                    attempts, will return the number of seconds before the next
                    attempt. If it returns ``None``, the message will be
                    permanently failed. The default backoff function simply
                    returns ``None`` and messages are never retried.
    :param bounce_factory: Function that produces a |Bounce| object given the
                           same parameters as the |Bounce| constructor. If the
                           function returns ``None``, no bounce is delivered.
                           By default, a new |Bounce| is created in every case.
    :param bounce_queue: |Queue| object that will be used for delivering bounce
                         messages. The default is ``self``.
    :param store_pool: Number of simultaneous operations performable against
                       the ``store`` object. Default is unlimited.
    :param relay_pool: Number of simultaneous operations performable against
                       the ``relay`` object. Default is unlimited.

    """

    def __init__(self, store, relay=None, backoff=None, bounce_factory=None,
                 bounce_queue=None, store_pool=None, relay_pool=None):
        super(Queue, self).__init__()
        self.store = store
        self.relay = relay
        self.backoff = backoff or self._default_backoff
        self.bounce_factory = bounce_factory or Bounce
        self.bounce_queue = bounce_queue or self
        self.wake = Event()
        self.queued = []
        self.active_ids = set()
        self.queued_ids = set()
        self.queued_lock = Semaphore(1)
        self.queue_policies = []
        self._use_pool('store_pool', store_pool)
        self._use_pool('relay_pool', relay_pool)

    def add_policy(self, policy):
        """Adds a |QueuePolicy| to be executed before messages are persisted
        to storage.

        :param policy: |QueuePolicy| object to execute.

        """
        if isinstance(policy, QueuePolicy):
            self.queue_policies.append(policy)
        else:
            raise TypeError('Argument not a QueuePolicy.')

    @staticmethod
    def _default_backoff(envelope, attempts):
        pass

    def _run_policies(self, envelope):
        results = [envelope]

        def recurse(current, i):
            try:
                policy = self.queue_policies[i]
            except IndexError:
                return
            ret = policy.apply(current)
            if ret:
                results.remove(current)
                results.extend(ret)
                for env in ret:
                    recurse(env, i+1)
            else:
                recurse(current, i+1)
        recurse(envelope, 0)
        return results

    def _use_pool(self, attr, pool):
        if pool is None:
            pass
        elif isinstance(pool, Pool):
            setattr(self, attr, pool)
        else:
            setattr(self, attr, Pool(pool))

    def _pool_run(self, which, func, *args, **kwargs):
        pool = getattr(self, which+'_pool', None)
        if pool:
            ret = pool.spawn(func, *args, **kwargs)
            return ret.get()
        else:
            return func(*args, **kwargs)

    def _pool_imap(self, which, func, *iterables):
        pool = getattr(self, which+'_pool', gevent)
        threads = imap(pool.spawn, repeat(func), *iterables)
        ret = []
        for thread in threads:
            thread.join()
            ret.append(thread.exception or thread.value)
        return ret

    def _pool_spawn(self, which, func, *args, **kwargs):
        pool = getattr(self, which+'_pool', gevent)
        return pool.spawn(func, *args, **kwargs)

    def _add_queued(self, entry):
        timestamp, id = entry
        if id not in self.queued_ids | self.active_ids:
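            # bisect.insort keeps self.queued sorted by next-attempt timestamp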
            bisect.insort(self.queued, entry)
            self.queued_ids.add(id)
            self.wake.set()

    def enqueue(self, envelope):
        """Drops a new message in the queue for delivery. The first delivery
        attempt is made immediately (depending on relay pool availability).
        This method is not typically called directly, |Edge| objects use it
        when they receive new messages.

        :param envelope: |Envelope| object to enqueue.
        :returns: Zipped list of envelopes and their respective queue IDs (or
                  thrown :exc:`QueueError` objects).

        """
        now = time.time()
        envelopes = self._run_policies(envelope)
        ids = self._pool_imap('store', self.store.write, envelopes,
                              repeat(now))
        results = list(zip(envelopes, ids))
        for env, id in results:
            if not isinstance(id, BaseException):
                if self.relay and id not in self.active_ids:
                    self.active_ids.add(id)
                    self._pool_spawn('relay', self._attempt, id, env, 0)
            elif not isinstance(id, QueueError):
                raise id  # Re-raise exceptions that are not QueueError.
        return results

    def _load_all(self):
        for entry in self.store.load():
            self._add_queued(entry)

    def _remove(self, id):
        self._pool_spawn('store', self.store.remove, id)
        self.queued_ids.discard(id)
        self.active_ids.discard(id)

    def _bounce(self, envelope, reply):
        bounce = self.bounce_factory(envelope, reply)
        if bounce:
            return self.bounce_queue.enqueue(bounce)

    def _perm_fail(self, id, envelope, reply):
        if id is not None:
            self._remove(id)
        if envelope.sender:  # Can't bounce to null-sender.
            self._pool_spawn('bounce', self._bounce, envelope, reply)

    def _split_by_reply(self, envelope, replies):
        if isinstance(replies, Reply):
            return [(replies, envelope)]
        groups = []
        for i, rcpt in enumerate(envelope.recipients):
            for reply, group_env in groups:
                if replies[i] == reply:
                    group_env.recipients.append(rcpt)
                    break
            else:
                group_env = envelope.copy([rcpt])
                groups.append((replies[i], group_env))
        return groups

    def _retry_later(self, id, envelope, replies):
        attempts = self.store.increment_attempts(id)
        wait = self.backoff(envelope, attempts)
        if wait is None:
            for reply, group_env in self._split_by_reply(envelope, replies):
                reply.message += ' (Too many retries)'
                self._perm_fail(None, group_env, reply)
            self._remove(id)
            return False
        else:
            when = time.time() + wait
            self.store.set_timestamp(id, when)
            self.active_ids.discard(id)
            self._add_queued((when, id))
            return True

    def _attempt(self, id, envelope, attempts):
        try:
            results = self.relay._attempt(envelope, attempts)
        except TransientRelayError as e:
            self._pool_spawn('store', self._retry_later, id, envelope, e.reply)
        except PermanentRelayError as e:
            self._perm_fail(id, envelope, e.reply)
        except Exception as e:
            logging.log_exception(__name__)
            reply = Reply('450', '4.0.0 Unhandled delivery error: ' + str(e))
            self._pool_spawn('store', self._retry_later, id, envelope, reply)
            raise
        else:
            # The bare collections aliases were removed in Python 3.10; the
            # ABCs live in collections.abc (available since 3.3).
            if isinstance(results, collections.abc.Mapping):
                self._handle_partial_relay(id, envelope, attempts, results)
            elif isinstance(results, collections.abc.Sequence):
                results = dict(zip(envelope.recipients, results))
                self._handle_partial_relay(id, envelope, attempts, results)
            else:
                self._remove(id)

    def _handle_partial_relay(self, id, envelope, attempts, results):
        delivered = set()
        tempfails = []
        permfails = []
        for rcpt, rcpt_res in results.items():
            if rcpt_res is None or isinstance(rcpt_res, Reply):
                delivered.add(envelope.recipients.index(rcpt))
            elif isinstance(rcpt_res, PermanentRelayError):
                delivered.add(envelope.recipients.index(rcpt))
                permfails.append((rcpt, rcpt_res.reply))
            elif isinstance(rcpt_res, TransientRelayError):
                tempfails.append((rcpt, rcpt_res.reply))
        if permfails:
            rcpts, replies = zip(*permfails)
            fail_env = envelope.copy(rcpts)
            for reply, group_env in self._split_by_reply(fail_env, replies):
                self._perm_fail(None, group_env, reply)
        if tempfails:
            rcpts, replies = zip(*tempfails)
            fail_env = envelope.copy(rcpts)
            if not self._retry_later(id, fail_env, replies):
                return
        else:
            self.store.remove(id)
            return
        self.store.set_recipients_delivered(id, delivered)

    def _dequeue(self, id):
        try:
            envelope, attempts = self.store.get(id)
        except KeyError:
            return
        if id not in self.active_ids:
            self.active_ids.add(id)
            self._pool_spawn('relay', self._attempt, id, envelope, attempts)

    def _check_ready(self, now):
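        # self.queued stays sorted by timestamp (bisect.insort in
        # _add_queued), so we can stop scanning at the first future entry.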
        last_i = 0
        for i, entry in enumerate(self.queued):
            timestamp, entry_id = entry
            if now >= timestamp:
                self._pool_spawn('store', self._dequeue, entry_id)
                last_i = i+1
            else:
                break
        if last_i > 0:
            self.queued = self.queued[last_i:]
            self.queued_ids = set([id for _, id in self.queued])

    def _wait_store(self):
        while True:
            try:
                for entry in self.store.wait():
                    self._add_queued(entry)
            except NotImplementedError:
                return

    def _wait_ready(self, now):
        try:
            first = self.queued[0]
        except IndexError:
            self.wake.wait()
            self.wake.clear()
            return
        first_timestamp = first[0]
        if first_timestamp > now:
            self.wake.wait(first_timestamp-now)
            self.wake.clear()

    def flush(self):
        """Attempts to immediately flush all messages waiting in the queue,
        regardless of their retry timers.

        .. warning::

           This can be a very expensive operation, use with care.

        """
        self.wake.set()
        self.wake.clear()
        self.queued_lock.acquire()
        try:
            for entry in self.queued:
                self._pool_spawn('store', self._dequeue, entry[1])
            self.queued = []
        finally:
            self.queued_lock.release()

    def kill(self):
        """This method is used by |Queue| and |Queue|-like objects to properly
        end any associated services (such as running :class:`~gevent.Greenlet`
        threads) and close resources.

        """
        super(Queue, self).kill()

    def _run(self):
        if not self.relay:
            return
        self._pool_spawn('store', self._load_all)
        self._pool_spawn('store', self._wait_store)
        while True:
            self.queued_lock.acquire()
            try:
                now = time.time()
                self._check_ready(now)
                self._wait_ready(now)
            finally:
                self.queued_lock.release()
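
A hedged usage sketch for the queue above. The store, relay, and Envelope constructors are assumptions modeled on slimta-style APIs and are not taken from this example:

# Hypothetical bootstrap; DictStorage, StaticSmtpRelay and Envelope are
# illustrative stand-ins for real slimta components.
queue = Queue(store=DictStorage(), relay=StaticSmtpRelay('mail.example.com'))
queue.start()                 # runs _run() in its own greenlet
env = Envelope('from@example.com', ['to@example.com'])
results = queue.enqueue(env)  # [(envelope, queue_id_or_QueueError), ...]
queue.flush()                 # retry everything now, ignoring backoff timers
queue.kill()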
Beispiel #52
0
class TarantoolInspector(StreamServer):
    """
    Tarantool inspector daemon. Usage:
    inspector = TarantoolInspector('localhost', 8080)
    inspector.start()
    # run some tests
    inspector.stop()
    """

    def __init__(self, host, port):
        # When a specific port range has been acquired for the current worker,
        # don't let the OS pick a port outside that range.
        if port == 0:
            port = find_port()
        super(TarantoolInspector, self).__init__((host, port))
        self.parser = None

    def start(self):
        super(TarantoolInspector, self).start()
        os.environ['INSPECTOR'] = str(self.server_port)

    def stop(self):
        del os.environ['INSPECTOR']

    def set_parser(self, parser):
        self.parser = parser
        self.sem = Semaphore()

    @staticmethod
    def readline(socket, delimiter='\n', size=4096):
        result = ''
        data = True

        while data:
            try:
                data = socket.recv(size)
            except IOError:
                # swallow the 'connection refused' errors raised when the instance halts
                data = ''
            result += data

            while result.find(delimiter) != -1:
                line, result = result.split(delimiter, 1)
                yield line
        return

    def handle(self, socket, addr):
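        # self.sem (created in set_parser) serializes handle() calls so only
        # one client connection drives the parser at a time.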
        if self.parser is None:
            raise AttributeError('Parser is not defined')
        self.sem.acquire()

        for line in self.readline(socket):
            try:
                result = self.parser.parse_preprocessor(line)
            except (KeyboardInterrupt, TarantoolStartError):
                # propagate to the main greenlet
                raise
            except Exception as e:
                self.parser.kill_current_test()
                color_stdout('\nTarantoolInspector.handle() received the following error:\n' +
                    traceback.format_exc() + '\n', schema='error')
                result = { "error": repr(e) }
            if result is None:
                result = True
            result = yaml.dump(result)
            if not result.endswith('...\n'):
                result = result + '...\n'
            socket.sendall(result)

        self.sem.release()

    def cleanup_nondefault(self):
        if self.parser:
            self.parser.cleanup_nondefault()
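
The readline() helper above is a generator that re-splits arbitrary recv() chunks on a delimiter. A standalone sketch of the same buffering pattern, fed from any iterable of chunks instead of a socket (names are illustrative):

def split_stream(chunks, delimiter='\n'):
    # Same idea as TarantoolInspector.readline(): accumulate chunks and
    # yield complete lines; a trailing partial line is silently dropped.
    buf = ''
    for chunk in chunks:
        buf += chunk
        while delimiter in buf:
            line, buf = buf.split(delimiter, 1)
            yield line

assert list(split_stream(['ab\ncd', '\nef'])) == ['ab', 'cd']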
Beispiel #53
0
class ServiceBase(object):
    def __init__(self, s_id, conf):
        self.s_id = s_id
        self.conf = conf
        self.start_time = None
        self.finish_time = None
        self.puppet = None
        self.lock = Semaphore()
        self.queue = set()
        self._status = "Unknown"
        self.started = False

    def start(self):
        with self.lock:
            if self.started:
                return
            # print "Starting service", self.s_id
            self.start_time = time.time()
            self._status = "Booting"
            self.real_start()
            self.started = True

    def terminate(self):
        if not self.puppet:
            return
        self.real_terminate()
        self._status = "Terminated"
        self.puppet = None
        self.finish_time = time.time()

    def real_start(self):
        raise NotImplementedError

    def real_terminate(self):
        raise NotImplementedError

    def record_task(self, task):
        self.queue.add(task.tid)

    def run(self, tid, task, *argv, **kwargs):
        task_info = Husky.dumps((task, argv, kwargs))
        self.puppet.submit_task(tid, task_info)
        res = self.puppet.fetch_result(tid)
        self.queue.discard(tid)
        return Husky.loads(res)

    def __getattr__(self, item):
        try:
            if item == "tasks":
                return list(self.queue)
            elif item == "status":
                if self.puppet:
                    return self.puppet.get_attr("status")
                return self._status
            elif item in ("cpu", "memory"):
                if self.puppet:
                    return self.puppet.get_attr(item)
                return None
        except Exception:
            return "Unknown"
        # Unknown names must raise, or hasattr() and friends silently get None.
        raise AttributeError(item)

    def __repr__(self):
        return "%s-%d" % (self.conf, self.s_id)
Beispiel #54
0
class Rotkehlchen():
    def __init__(self, args):
        self.lock = Semaphore()
        self.lock.acquire()
        self.results_cache: ResultCache = dict()
        self.premium = None
        self.connected_exchanges = []
        self.user_is_logged_in = False

        logfilename = None
        if args.logtarget == 'file':
            logfilename = args.logfile

        if args.loglevel == 'debug':
            loglevel = logging.DEBUG
        elif args.loglevel == 'info':
            loglevel = logging.INFO
        elif args.loglevel == 'warn':
            loglevel = logging.WARN
        elif args.loglevel == 'error':
            loglevel = logging.ERROR
        elif args.loglevel == 'critical':
            loglevel = logging.CRITICAL
        else:
            raise ValueError('Should never get here. Illegal log value')

        logging.basicConfig(
            filename=logfilename,
            filemode='w',
            level=loglevel,
            format='%(asctime)s -- %(levelname)s:%(name)s:%(message)s',
            datefmt='%d/%m/%Y %H:%M:%S %Z',
        )

        if not args.logfromothermodules:
            logging.getLogger('zerorpc').setLevel(logging.CRITICAL)
            logging.getLogger('zerorpc.channel').setLevel(logging.CRITICAL)
            logging.getLogger('urllib3').setLevel(logging.CRITICAL)
            logging.getLogger('urllib3.connectionpool').setLevel(
                logging.CRITICAL)

        self.sleep_secs = args.sleep_secs
        self.data_dir = args.data_dir
        self.args = args

        self.poloniex = None
        self.kraken = None
        self.bittrex = None
        self.bitmex = None
        self.binance = None

        self.msg_aggregator = MessagesAggregator()
        self.data = DataHandler(self.data_dir, self.msg_aggregator)
        # Initialize the Inquirer singleton
        Inquirer(data_dir=self.data_dir)

        self.lock.release()
        self.shutdown_event = gevent.event.Event()

    def initialize_exchanges(
            self, exchange_credentials: Dict[str, ApiCredentials]) -> None:
        # initialize exchanges for which we have keys and are not already initialized
        if self.kraken is None and 'kraken' in exchange_credentials:
            self.kraken = Kraken(
                api_key=exchange_credentials['kraken'].api_key,
                secret=exchange_credentials['kraken'].api_secret,
                user_directory=self.user_directory,
                msg_aggregator=self.msg_aggregator,
                usd_eur_price=Inquirer().query_fiat_pair(S_EUR, S_USD),
            )
            self.connected_exchanges.append('kraken')
            self.trades_historian.set_exchange('kraken', self.kraken)

        if self.poloniex is None and 'poloniex' in exchange_credentials:
            self.poloniex = Poloniex(
                api_key=exchange_credentials['poloniex'].api_key,
                secret=exchange_credentials['poloniex'].api_secret,
                user_directory=self.user_directory,
                msg_aggregator=self.msg_aggregator,
            )
            self.connected_exchanges.append('poloniex')
            self.trades_historian.set_exchange('poloniex', self.poloniex)

        if self.bittrex is None and 'bittrex' in exchange_credentials:
            self.bittrex = Bittrex(
                api_key=exchange_credentials['bittrex'].api_key,
                secret=exchange_credentials['bittrex'].api_secret,
                user_directory=self.user_directory,
                msg_aggregator=self.msg_aggregator,
            )
            self.connected_exchanges.append('bittrex')
            self.trades_historian.set_exchange('bittrex', self.bittrex)

        if self.binance is None and 'binance' in exchange_credentials:
            self.binance = Binance(
                api_key=exchange_credentials['binance'].api_key,
                secret=exchange_credentials['binance'].api_secret,
                data_dir=self.user_directory,
                msg_aggregator=self.msg_aggregator,
            )
            self.connected_exchanges.append('binance')
            self.trades_historian.set_exchange('binance', self.binance)

        if self.bitmex is None and 'bitmex' in exchange_credentials:
            self.bitmex = Bitmex(
                api_key=exchange_credentials['bitmex'].api_key,
                secret=exchange_credentials['bitmex'].api_secret,
                user_directory=self.user_directory,
            )
            self.connected_exchanges.append('bitmex')
            self.trades_historian.set_exchange('bitmex', self.bitmex)

    def remove_all_exchanges(self):
        if self.kraken is not None:
            self.delete_exchange_data('kraken')
        if self.poloniex is not None:
            self.delete_exchange_data('poloniex')
        if self.bittrex is not None:
            self.delete_exchange_data('bittrex')
        if self.binance is not None:
            self.delete_exchange_data('binance')
        if self.bitmex is not None:
            self.delete_exchange_data('bitmex')

    def unlock_user(
        self,
        user: str,
        password: str,
        create_new: bool,
        sync_approval: bool,
        api_key: ApiKey,
        api_secret: ApiSecret,
    ) -> None:
        """Unlocks an existing user or creates a new one if `create_new` is True"""
        log.info(
            'Unlocking user',
            user=user,
            create_new=create_new,
            sync_approval=sync_approval,
        )
        # unlock or create the DB
        self.password = password
        self.user_directory = self.data.unlock(user, password, create_new)
        self.last_data_upload_ts = self.data.db.get_last_data_upload_ts()
        self.premium_sync_manager = PremiumSyncManager(data=self.data,
                                                       password=password)
        try:
            self.premium = self.premium_sync_manager.try_premium_at_start(
                api_key=api_key,
                api_secret=api_secret,
                username=user,
                create_new=create_new,
                sync_approval=sync_approval,
            )
        except AuthenticationError:
            # It means that our credentials were not accepted by the server
            # or some other error happened
            pass

        exchange_credentials = self.data.db.get_exchange_credentials()
        settings = self.data.db.get_settings()
        historical_data_start = settings['historical_data_start']
        eth_rpc_endpoint = settings['eth_rpc_endpoint']
        self.trades_historian = TradesHistorian(
            user_directory=self.user_directory,
            db=self.data.db,
            eth_accounts=self.data.get_eth_accounts(),
            msg_aggregator=self.msg_aggregator,
        )
        # Initialize the price historian singleton
        PriceHistorian(
            data_directory=self.data_dir,
            history_date_start=historical_data_start,
            cryptocompare=Cryptocompare(data_directory=self.data_dir),
        )
        db_settings = self.data.db.get_settings()
        self.accountant = Accountant(
            profit_currency=self.data.main_currency(),
            user_directory=self.user_directory,
            msg_aggregator=self.msg_aggregator,
            create_csv=True,
            ignored_assets=self.data.db.get_ignored_assets(),
            include_crypto2crypto=db_settings['include_crypto2crypto'],
            taxfree_after_period=db_settings['taxfree_after_period'],
            include_gas_costs=db_settings['include_gas_costs'],
        )

        # Initialize the rotkehlchen logger
        LoggingSettings(anonymized_logs=db_settings['anonymized_logs'])
        self.initialize_exchanges(exchange_credentials)

        ethchain = Ethchain(eth_rpc_endpoint)
        self.blockchain = Blockchain(
            blockchain_accounts=self.data.db.get_blockchain_accounts(),
            owned_eth_tokens=self.data.db.get_owned_tokens(),
            ethchain=ethchain,
            msg_aggregator=self.msg_aggregator,
        )
        self.user_is_logged_in = True

    def logout(self) -> None:
        if not self.user_is_logged_in:
            return

        user = self.data.username
        log.info(
            'Logging out user',
            user=user,
        )
        del self.blockchain
        self.remove_all_exchanges()

        # Reset rotkehlchen logger to default
        LoggingSettings(anonymized_logs=DEFAULT_ANONYMIZED_LOGS)

        del self.accountant
        del self.trades_historian

        if self.premium is not None:
            del self.premium
        self.data.logout()
        self.password = ''

        self.user_is_logged_in = False
        log.info(
            'User successfully logged out',
            user=user,
        )

    def set_premium_credentials(self, api_key: ApiKey,
                                api_secret: ApiSecret) -> None:
        """
        Raises IncorrectApiKeyFormat if the given key is not in a proper format
        Raises AuthenticationError if the given key is rejected by the Rotkehlchen server
        """
        log.info('Setting new premium credentials')

        if self.premium is not None:
            self.premium.set_credentials(api_key, api_secret)
        else:
            self.premium = premium_create_and_verify(api_key, api_secret)

        self.data.set_premium_credentials(api_key, api_secret)

    def start(self):
        return gevent.spawn(self.main_loop)

    def main_loop(self):
        while not self.shutdown_event.wait(MAIN_LOOP_SECS_DELAY):
            log.debug('Main loop start')
            if self.poloniex is not None:
                self.poloniex.main_logic()
            if self.kraken is not None:
                self.kraken.main_logic()

            self.premium_sync_manager.maybe_upload_data_to_server()

            log.debug('Main loop end')

    def add_blockchain_account(
        self,
        blockchain: SupportedBlockchain,
        account: BlockchainAddress,
    ) -> Dict:
        try:
            new_data = self.blockchain.add_blockchain_account(
                blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.add_blockchain_account(blockchain, account)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_blockchain_account(
        self,
        blockchain: SupportedBlockchain,
        account: BlockchainAddress,
    ):
        try:
            new_data = self.blockchain.remove_blockchain_account(
                blockchain, account)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))
        self.data.remove_blockchain_account(blockchain, account)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def add_owned_eth_tokens(self, tokens: List[str]):
        ethereum_tokens = [
            EthereumToken(identifier=identifier) for identifier in tokens
        ]
        try:
            new_data = self.blockchain.track_new_tokens(ethereum_tokens)
        except (InputError, EthSyncError) as e:
            return simple_result(False, str(e))

        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def remove_owned_eth_tokens(self, tokens: List[str]):
        ethereum_tokens = [
            EthereumToken(identifier=identifier) for identifier in tokens
        ]
        try:
            new_data = self.blockchain.remove_eth_tokens(ethereum_tokens)
        except InputError as e:
            return simple_result(False, str(e))
        self.data.write_owned_eth_tokens(self.blockchain.owned_eth_tokens)
        return accounts_result(new_data['per_account'], new_data['totals'])

    def process_history(self, start_ts, end_ts):
        (
            error_or_empty,
            history,
            margin_history,
            loan_history,
            asset_movements,
            eth_transactions,
        ) = self.trades_historian.get_history(
            start_ts=0,  # entire-history processing needs the full history available
            end_ts=ts_now(),
            end_at_least_ts=end_ts,
        )
        result = self.accountant.process_history(
            start_ts,
            end_ts,
            history,
            margin_history,
            loan_history,
            asset_movements,
            eth_transactions,
        )
        return result, error_or_empty

    def query_fiat_balances(self):
        log.info('query_fiat_balances called')
        result = {}
        balances = self.data.get_fiat_balances()
        for currency, amount in balances.items():
            amount = FVal(amount)
            usd_rate = Inquirer().query_fiat_pair(currency, 'USD')
            result[currency] = {
                'amount': amount,
                'usd_value': amount * usd_rate,
            }

        return result

    def query_balances(
        self,
        requested_save_data: bool = False,
        timestamp: Optional[Timestamp] = None,  # assumes Optional is imported from typing
    ) -> Dict[str, Any]:
        """Query all balances rotkehlchen can see.

        If requested_save_data is True then the data are saved in the DB.
        If timestamp is None then the current timestamp is used.
        If a timestamp is given then that is the time that the balances are going
        to be saved in the DB

        Returns a dictionary with the queried balances.
        """
        log.info('query_balances called',
                 requested_save_data=requested_save_data)

        balances = {}
        problem_free = True
        for exchange in self.connected_exchanges:
            exchange_balances, _ = getattr(self, exchange).query_balances()
            # If we got an error, disregard that exchange but make sure we don't save data
            if not isinstance(exchange_balances, dict):
                problem_free = False
            else:
                balances[exchange] = exchange_balances

        result, error_or_empty = self.blockchain.query_balances()
        if error_or_empty == '':
            balances['blockchain'] = result['totals']
        else:
            problem_free = False

        result = self.query_fiat_balances()
        if result != {}:
            balances['banks'] = result

        combined = combine_stat_dicts([v for k, v in balances.items()])
        total_usd_per_location = [(k, dict_get_sumof(v, 'usd_value'))
                                  for k, v in balances.items()]

        # calculate net usd value
        net_usd = FVal(0)
        for _, v in combined.items():
            net_usd += FVal(v['usd_value'])

        stats: Dict[str, Any] = {
            'location': {},
            'net_usd': net_usd,
        }
        for entry in total_usd_per_location:
            name = entry[0]
            total = entry[1]
            if net_usd != FVal(0):
                percentage = (total / net_usd).to_percentage()
            else:
                percentage = '0%'
            stats['location'][name] = {
                'usd_value': total,
                'percentage_of_net_value': percentage,
            }

        for k, v in combined.items():
            if net_usd != FVal(0):
                percentage = (v['usd_value'] / net_usd).to_percentage()
            else:
                percentage = '0%'
            combined[k]['percentage_of_net_value'] = percentage

        result_dict = merge_dicts(combined, stats)

        allowed_to_save = requested_save_data or self.data.should_save_balances()
        if problem_free and allowed_to_save:
            if not timestamp:
                timestamp = Timestamp(int(time.time()))
            self.data.save_balances_data(data=result_dict, timestamp=timestamp)
            log.debug('query_balances data saved')
        else:
            log.debug(
                'query_balances data not saved',
                allowed_to_save=allowed_to_save,
                problem_free=problem_free,
            )

        # After adding it to the saved file we can overlay additional data that
        # is not required to be saved in the history file
        try:
            details = self.data.accountant.details
            for asset, (tax_free_amount, average_buy_value) in details.items():
                if asset not in result_dict:
                    continue

                result_dict[asset]['tax_free_amount'] = tax_free_amount
                result_dict[asset]['average_buy_value'] = average_buy_value

                current_price = result_dict[asset]['usd_value'] / result_dict[
                    asset]['amount']
                if average_buy_value != FVal(0):
                    result_dict[asset]['percent_change'] = (
                        ((current_price - average_buy_value) /
                         average_buy_value) * 100)
                else:
                    result_dict[asset]['percent_change'] = 'INF'

        except AttributeError:
            pass

        return result_dict

    def set_main_currency(self, currency):
        with self.lock:
            self.data.set_main_currency(currency, self.accountant)
            if currency != S_USD:
                self.usd_to_main_currency_rate = Inquirer().query_fiat_pair(
                    S_USD, currency)

    def set_settings(self, settings):
        log.info('Add new settings')

        message = ''
        with self.lock:
            if 'eth_rpc_endpoint' in settings:
                result, msg = self.blockchain.set_eth_rpc_endpoint(
                    settings['eth_rpc_endpoint'])
                if not result:
                    # Don't save it in the DB
                    del settings['eth_rpc_endpoint']
                    message += "\nEthereum RPC endpoint not set: " + msg

            if 'main_currency' in settings:
                given_symbol = settings['main_currency']
                try:
                    main_currency = Asset(given_symbol)
                except UnknownAsset:
                    return False, f'Unknown fiat currency {given_symbol} provided'

                if not main_currency.is_fiat():
                    msg = (
                        f'Provided symbol for main currency {given_symbol} is '
                        f'not a fiat currency')
                    return False, msg

                if main_currency != A_USD:
                    self.usd_to_main_currency_rate = Inquirer().query_fiat_pair(
                        S_USD,
                        main_currency.identifier,
                    )

            res, msg = self.accountant.customize(settings)
            if not res:
                message += '\n' + msg
                return False, message

            _, msg = self.data.set_settings(settings, self.accountant)
            if msg != '':
                message += '\n' + msg

            # Always return success here but with a message
            return True, message

    def setup_exchange(
        self,
        name: str,
        api_key: str,
        api_secret: str,
    ) -> Tuple[bool, str]:
        """
        Setup a new exchange with an api key and an api secret

        By default the api keys are always validated unless validate is False.
        """
        log.info('setup_exchange', name=name)
        if name not in SUPPORTED_EXCHANGES:
            return False, 'Attempted to register unsupported exchange {}'.format(
                name)

        if getattr(self, name) is not None:
            return False, 'Exchange {} is already registered'.format(name)

        credentials_dict = {}
        api_credentials = ApiCredentials.serialize(api_key=api_key,
                                                   api_secret=api_secret)
        credentials_dict[name] = api_credentials
        self.initialize_exchanges(credentials_dict)

        exchange = getattr(self, name)
        result, message = exchange.validate_api_key()
        if not result:
            log.error(
                'Failed to validate API key for exchange',
                name=name,
                error=message,
            )
            self.delete_exchange_data(name)
            return False, message

        # Success, save the result in the DB
        self.data.db.add_exchange(name, api_key, api_secret)
        return True, ''

    def delete_exchange_data(self, name):
        self.connected_exchanges.remove(name)
        self.trades_historian.set_exchange(name, None)
        delattr(self, name)
        setattr(self, name, None)

    def remove_exchange(self, name):
        if getattr(self, name) is None:
            return False, 'Exchange {} is not registered'.format(name)

        self.delete_exchange_data(name)
        # Success, remove it also from the DB
        self.data.db.remove_exchange(name)
        return True, ''

    def query_periodic_data(self) -> Dict[str, Union[bool, Timestamp]]:
        """Query for frequently changing data"""
        result = {}

        if self.user_is_logged_in:
            result['last_balance_save'] = self.data.db.get_last_balance_save_time()
            result['eth_node_connection'] = self.blockchain.ethchain.connected
            result['history_process_start_ts'] = self.accountant.started_processing_timestamp
            result['history_process_current_ts'] = self.accountant.currently_processing_timestamp
        return result

    def shutdown(self):
        self.logout()
        self.shutdown_event.set()
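
A hedged lifecycle sketch. The argparse.Namespace attributes mirror what __init__ reads above; the values are illustrative, not the project's real defaults:

import argparse

args = argparse.Namespace(
    logtarget='stdout', logfile=None, loglevel='info',
    logfromothermodules=False, sleep_secs=20, data_dir='/tmp/rotki-data',
)
rotki = Rotkehlchen(args)
glet = rotki.start()   # spawns main_loop() in a gevent greenlet
# ... unlock_user(), query_balances(), etc. ...
rotki.shutdown()       # logs out and sets the shutdown event
glet.join()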
Beispiel #55
0
class ThreadPool(GroupMappingMixin):

    def __init__(self, maxsize, hub=None):
        if hub is None:
            hub = get_hub()
        self.hub = hub
        self._maxsize = 0
        self.manager = None
        self.pid = os.getpid()
        self.fork_watcher = hub.loop.fork(ref=False)
        self._init(maxsize)

    def _set_maxsize(self, maxsize):
        if not isinstance(maxsize, integer_types):
            raise TypeError('maxsize must be integer: %r' % (maxsize, ))
        if maxsize < 0:
            raise ValueError('maxsize must not be negative: %r' % (maxsize, ))
        difference = maxsize - self._maxsize
        self._semaphore.counter += difference
        self._maxsize = maxsize
        self.adjust()
        # make sure all currently blocking spawn() start unlocking if maxsize increased
        self._semaphore._start_notify()

    def _get_maxsize(self):
        return self._maxsize

    maxsize = property(_get_maxsize, _set_maxsize)

    def __repr__(self):
        return '<%s at 0x%x %s/%s/%s>' % (self.__class__.__name__, id(self), len(self), self.size, self.maxsize)

    def __len__(self):
        # XXX just do unfinished_tasks property
        return self.task_queue.unfinished_tasks

    def _get_size(self):
        return self._size

    def _set_size(self, size):
        if size < 0:
            raise ValueError('Size of the pool cannot be negative: %r' % (size, ))
        if size > self._maxsize:
            raise ValueError('Size of the pool cannot be bigger than maxsize: %r > %r' % (size, self._maxsize))
        if self.manager:
            self.manager.kill()
        while self._size < size:
            self._add_thread()
        delay = 0.0001
        while self._size > size:
            while self._size - size > self.task_queue.unfinished_tasks:
                self.task_queue.put(None)
            if getcurrent() is self.hub:
                break
            sleep(delay)
            delay = min(delay * 2, .05)
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    size = property(_get_size, _set_size)

    def _init(self, maxsize):
        self._size = 0
        self._semaphore = Semaphore(1)
        self._lock = Lock()
        self.task_queue = Queue()
        self._set_maxsize(maxsize)

    def _on_fork(self):
        # fork() only leaves one thread; also screws up locks;
        # let's re-create locks and threads
        pid = os.getpid()
        if pid != self.pid:
            self.pid = pid
            # Do not mix fork() and threads; since fork() only copies one thread,
            # all objects referenced by other threads have refcounts that will
            # never go down to 0.
            self._init(self._maxsize)

    def join(self):
        delay = 0.0005
        while self.task_queue.unfinished_tasks > 0:
            sleep(delay)
            delay = min(delay * 2, .05)

    def kill(self):
        self.size = 0

    def _adjust_step(self):
        # if there is a possibility & necessity for adding a thread, do it
        while self._size < self._maxsize and self.task_queue.unfinished_tasks > self._size:
            self._add_thread()
        # while the number of threads is more than maxsize, kill one
        # we do not check what's already in task_queue - it could be all Nones
        while self._size - self._maxsize > self.task_queue.unfinished_tasks:
            self.task_queue.put(None)
        if self._size:
            self.fork_watcher.start(self._on_fork)
        else:
            self.fork_watcher.stop()

    def _adjust_wait(self):
        delay = 0.0001
        while True:
            self._adjust_step()
            if self._size <= self._maxsize:
                return
            sleep(delay)
            delay = min(delay * 2, .05)

    def adjust(self):
        self._adjust_step()
        if not self.manager and self._size > self._maxsize:
            # might need to feed more Nones into the pool
            self.manager = Greenlet.spawn(self._adjust_wait)

    def _add_thread(self):
        with self._lock:
            self._size += 1
        try:
            start_new_thread(self._worker, ())
        except:
            with self._lock:
                self._size -= 1
            raise

    def spawn(self, func, *args, **kwargs):
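        # Acquire a slot on the current semaphore; if a fork re-created
        # self._semaphore (see _on_fork/_init) while we were blocked, retry
        # on the replacement so the accounting stays consistent.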
        while True:
            semaphore = self._semaphore
            semaphore.acquire()
            if semaphore is self._semaphore:
                break
        try:
            task_queue = self.task_queue
            result = AsyncResult()
            thread_result = ThreadResult(result, hub=self.hub)
            task_queue.put((func, args, kwargs, thread_result))
            self.adjust()
            # rawlink() must be the last call
            result.rawlink(lambda *args: self._semaphore.release())
            # XXX this _semaphore.release() is competing for order with get()
            # XXX this is not good, just make ThreadResult release the semaphore before doing anything else
        except:
            semaphore.release()
            raise
        return result

    def _decrease_size(self):
        if sys is None:
            return
        _lock = getattr(self, '_lock', None)
        if _lock is not None:
            with _lock:
                self._size -= 1

    def _worker(self):
        need_decrease = True
        try:
            while True:
                task_queue = self.task_queue
                task = task_queue.get()
                try:
                    if task is None:
                        need_decrease = False
                        self._decrease_size()
                        # we want first to decrease size, then decrease unfinished_tasks
                        # otherwise, _adjust might think there's one more idle thread that
                        # needs to be killed
                        return
                    func, args, kwargs, thread_result = task
                    try:
                        value = func(*args, **kwargs)
                    except:
                        exc_info = getattr(sys, 'exc_info', None)
                        if exc_info is None:
                            return
                        thread_result.handle_error((self, func), exc_info())
                    else:
                        if sys is None:
                            return
                        thread_result.set(value)
                        del value
                    finally:
                        del func, args, kwargs, thread_result, task
                finally:
                    if sys is None:
                        return
                    task_queue.task_done()
        finally:
            if need_decrease:
                self._decrease_size()

    def apply_e(self, expected_errors, function, args=None, kwargs=None):
        # Deprecated but never documented. In the past, before
        # self.apply() allowed all errors to be raised to the caller,
        # expected_errors allowed a caller to specify a set of errors
        # they wanted to be raised, through the wrap_errors function.
        # In practice, it always took the value Exception or
        # BaseException.
        return self.apply(function, args, kwargs)

    def _apply_immediately(self):
        # we always pass apply() off to the threadpool
        return False

    def _apply_async_cb_spawn(self, callback, result):
        callback(result)

    def _apply_async_use_greenlet(self):
        # Always go to Greenlet because our self.spawn uses threads
        return True
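
A minimal usage sketch for the pool above, mirroring gevent's ThreadPool API:

import time

pool = ThreadPool(maxsize=3)
# spawn() returns AsyncResult objects; at most 3 tasks run concurrently.
results = [pool.spawn(time.sleep, 0.1) for _ in range(6)]
pool.join()                          # block until unfinished_tasks drains
print(all(r.ready() for r in results))
pool.kill()                          # size = 0 feeds None sentinels to workers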
Beispiel #56
0
class Mime(object):
    """
    Our mime response object.
    """
    def __init__(self, magic_file=None):
        """
        Create a new libmagic wrapper.
        magic_file - an alternate magic file to use instead of the default
        """

        # The lock allows us to be thread-safe
        self.lock = Semaphore(value=1)

        # Tracks the errno/errstr set after each call
        self.errno = 0
        self.errstr = ''

        # Load our magic file
        self.magic_file = magic_file

        # Initialize our flags
        # our flags
        self.flags = (MAGIC_MIME | MAGIC_MIME_ENCODING)

    def from_content(self, content, uncompress=False, fullscan=False):
        """
        detect the type based on the content specified.

          content - the content to process
          uncompress - Try to look inside compressed files.
          fullscan - Scan entire file and extract as much details as possible

        """
        if not HAS_LIBMAGIC:
            return None

        if not content:
            return MimeResponse()

        # our pointer to our libmagic object
        ptr = None

        # our flags
        flags = self.flags

        if fullscan:
            flags |= MAGIC_CONTINUE

        if uncompress:
            flags |= MAGIC_COMPRESS

        try:
            self.lock.acquire(blocking=True)
            # 'application/octet-stream; charset=binary'

            # Acquire a pointer
            ptr = _magic['open'](flags)

            # Load our magic file
            _magic['load'](ptr, self.magic_file)
            self.errno = _magic['errno'](ptr)
            if self.errno == 0:

                # Split our result into its type and encoding parts
                _typ, _enc = MAGIC_LIST_RE.split(self._tostr(_magic['buffer'](
                    ptr,
                    self._tobytes(content),
                    len(content),
                )))

                # Acquire our errorstr (if one exists)
                self.errstr = _magic['error'](ptr)

                _typs = []
                for n, _typ_re in enumerate(TYPE_PARSE_RE.finditer(_typ)):
                    if n == 0:
                        _typs.append((0, _typ_re.group('mtype')))
                    else:
                        _typs.append(
                            (
                                int(_typ_re.group('offset')),
                                _typ_re.group('mtype'),
                            ),
                        )

                _enc_re = ENCODING_PARSE_RE.match(_enc)
                if _enc_re:
                    _enc = _enc_re.group('encoding')

                mr = MimeResponse(mime_type=_typs, mime_encoding=_enc)

                # return our object
                return mr

        except TypeError:
            # This occurs if the buffer check returns None
            # Acquire our error string (if one exists)
            self.errstr = _magic['error'](ptr)
            if self.errstr:
                # an error occurred; return None
                return None

            # If we get here, we didn't even get an error
            return MimeResponse(
                mime_type=DEFAULT_MIME_TYPE,
                mime_encoding=DEFAULT_MIME_ENCODING,
            )

        finally:
            if ptr is not None:
                # Release our pointer
                _magic['close'](ptr)

            # Release our lock
            self.lock.release()

        # We failed if we got here, return nothing
        return None

    def from_file(self, path, uncompress=False, fullscan=False):
        """
        detect the type based on the file specified.

          path - the file to process
          uncompress - Try to look inside compressed files.
          fullscan - Scan entire file and extract as much details as possible

        """
        if not HAS_LIBMAGIC:
            return None

        if not path:
            return None

        # our pointer to our libmagic object
        ptr = None

        # our flags
        flags = self.flags

        if fullscan:
            flags |= MAGIC_CONTINUE

        if uncompress:
            flags |= MAGIC_COMPRESS

        # Default Response
        res = ""

        try:
            self.lock.acquire(blocking=True)
            # 'application/octet-stream; charset=binary'

            # Acquire a pointer
            ptr = _magic['open'](flags)

            # Load our magic file
            _magic['load'](ptr, self.magic_file)
            self.errno = _magic['errno'](ptr)
            if self.errno == 0:

                res = self._tostr(_magic['file'](ptr, self._tobytes(path)))

                # Acquire our errorstr (if one exists)
                self.errstr = _magic['error'](ptr)

                # Split our result into its type and encoding parts
                _typ, _enc = MAGIC_LIST_RE.split(res)

                _typs = []
                for n, _typ_re in enumerate(TYPE_PARSE_RE.finditer(_typ)):
                    if n == 0:
                        _typs.append((0, _typ_re.group('mtype')))
                    else:
                        _typs.append(
                            (
                                int(_typ_re.group('offset')),
                                _typ_re.group('mtype')
                            ),
                        )

                _enc_re = ENCODING_PARSE_RE.match(_enc)
                if _enc_re:
                    _enc = _enc_re.group('encoding')

                mr = MimeResponse(
                    mime_type=_typs,
                    mime_encoding=_enc,
                    extension=self.extension_from_filename(path),
                )

                # return our object
                return mr

        except TypeError:
            # This occurs if the buffer check returns None
            # Acquire our error string (if one exists)
            self.errstr = _magic['error'](ptr)
            if self.errstr:
                # an error occurred; return None
                return None

            # If we get here, we didn't even get an error
            return MimeResponse(
                mime_type=DEFAULT_MIME_TYPE,
                mime_encoding=DEFAULT_MIME_ENCODING,
                extension=self.extension_from_filename(path),
            )

        except ValueError:
            # This occurs when our regular expression extraction fails
            # because an error string (such as 'no file exists') was returned
            # as part of the response instead of the expected mime type.
            # Store this error:
            self.errstr = res

            # Now return None
            return None

        finally:
            if ptr is not None:
                # Release our pointer
                _magic['close'](ptr)

            # Release our lock
            self.lock.release()

        # We failed if we got here, return nothing
        return None

    def from_filename(self, filename):
        """
        detect the type based on the filename (and/or extension) specified.

        """

        if not filename:
            # Invalid
            return None

        mime_type = next((m for m in MIME_TYPES if m[1].match(filename)), None)

        if mime_type:
            return MimeResponse(
                mime_type=mime_type[0],
                mime_encoding=mime_type[2],
                extension=self.extension_from_filename(filename),
            )

        # No match; default response
        return MimeResponse(
            mime_type=DEFAULT_MIME_TYPE,
            mime_encoding=DEFAULT_MIME_ENCODING,
            extension=self.extension_from_filename(filename),
        )

    def from_bestguess(self, path):
        """
        First attempts to look at the file's contents (if it can); if not, it
        falls back to looking at the filename, and always returns its best
        guess.
        """
        mr = self.from_file(path)
        if mr is None or mr.type() in \
                (DEFAULT_MIME_TYPE, DEFAULT_MIME_EMTPTY_FILE):

            _mr = self.from_filename(path)
            if _mr is None:
                if mr is None:
                    # Not parseable
                    return None

                # Otherwise return the default octet-stream response and
                # intentionally do not count the empty file (if detected)
                return MimeResponse(
                    mime_type=DEFAULT_MIME_TYPE,
                    mime_encoding=DEFAULT_MIME_ENCODING,
                    extension=self.extension_from_filename(path),
                )

            elif mr is not None and _mr == DEFAULT_MIME_TYPE:
                return mr
            return _mr
        return mr

    def extension_from_mime(self, mime_type):
        """
        takes a mime type and returns the file extension that bests matches
        it. This function returns an empty string if the mime type can't
        be looked up correctly, otherwise it returns the matching extension.
        """

        if not mime_type:
            # Invalid; but return an empty string
            return ''

        # iterate over our list and return on our first match
        return next((m[3] for m in MIME_TYPES if mime_type == m[0]), '')

    def extension_from_filename(self, filename):
        """
        Takes a filename and extracts the extension from it (if possible).
        This function is a little like os.path.splitext(), which can take a
        file like test.jpg and return .jpg. It cannot, however, handle names
        like text.jpeg.gz (it would just return .gz in that case).

        This function (specifically written for newsreap/usenet parsing)
        will attempt to extract a worthy extension (all of it) from the
        filename. Hence:

            extension_from_filename('myfile.pdf.vol03+4.par2')
               will return '.pdf.vol03+4.par2'

        """

        if not filename:
            # Invalid; but return an empty string
            return ''

        result = EXTENSION_SEARCH_RE.search(filename)
        if result:
            ext = result.group('ext')
            if ext is not None:
                return ext

        # Nothing found if we reach here
        return ''

    def _tostr(self, s, encoding='utf-8'):
        if s is None:
            return None
        if isinstance(s, str):
            return s
        try:  # keep Python 2 compatibility
            return str(s, encoding)
        except TypeError:
            return str(s)

    def _tobytes(self, b, encoding='utf-8'):
        if b is None:
            return None
        if isinstance(b, bytes):
            return b
        try:  # keep Python 2 compatibility
            return bytes(b, encoding)
        except TypeError:
            return bytes(b)
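
A hedged usage sketch for the wrapper above; results depend on libmagic being available (HAS_LIBMAGIC) and on the module-level MIME_TYPES table:

m = Mime()
resp = m.from_bestguess('/tmp/report.pdf')
if resp is not None:
    print(resp.type())
# Extension extraction works without libmagic:
print(m.extension_from_filename('myfile.pdf.vol03+4.par2'))  # '.pdf.vol03+4.par2'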