예제 #1
0
class ConditionPoller(Thread):
    """
    Generic polling thread.

    Every ``interval`` seconds, call ``condition()``; if it returns a truthy
    value, pass that value to ``condition_callback``. If either call raises,
    polling stops and ``exception_callback`` (when not None) receives the
    exception.

    :param condition: zero-argument callable polled each interval.
    :param condition_callback: called with each truthy value from ``condition``.
    :param exception_callback: optional callable given the exception that
        stopped the poller; may be None.
    :param interval: seconds to wait between polls.
    """
    def __init__(self, condition, condition_callback, exception_callback, interval):
        self.polling_interval = interval
        # Initialised here so the flag can be read before shutdown() is called.
        self.is_shutting_down = False
        self._shutdown_now = Event()
        self._condition = condition
        self._callback = condition_callback
        self._on_exception = exception_callback
        super(ConditionPoller, self).__init__()

    def shutdown(self):
        """Ask the polling loop to exit; also interrupts its interval wait."""
        self.is_shutting_down = True
        self._shutdown_now.set()

    def run(self):
        try:
            while not self._shutdown_now.is_set():
                self._check_condition()
                # Event.wait doubles as a sleep that shutdown() can interrupt.
                self._shutdown_now.wait(self.polling_interval)
        except Exception:
            # Thread boundary: anything escaping _check_condition (e.g. an
            # exception raised by exception_callback itself) is logged here.
            log.error('thread failed', exc_info=True)

    def _check_condition(self):
        """Poll once; on any exception stop polling and report it."""
        try:
            value = self._condition()
            if value:
                self._callback(value)
        except Exception as e:
            log.debug('stopping poller after exception', exc_info=True)
            self.shutdown()
            if self._on_exception:
                self._on_exception(e)
예제 #2
0
파일: views.py 프로젝트: kblw/simple-chat
class Chat(object):
    """Long-polling chat views that share one wake-up event."""

    def __init__(self):
        # Pulsed (set, then cleared) after each saved message to wake
        # get_messages() callers that are blocked waiting.
        self.new_msg_event = Event()

    def write_message(self, request):
        """Save a POSTed message and wake long-polling readers.

        Returns 404 for unauthenticated or non-POST requests; otherwise a JSON
        object with ``success`` and, when validation fails, ``errors``.
        """
        if not request.user.is_authenticated() or request.method != 'POST':
            return HttpResponse(status=404)
        form = MessageForm(request.POST)
        output = dict(success=False)
        if form.is_valid():
            form.save(request.user)
            output['success'] = True
            # Only wake readers when a message was actually stored.
            self.new_msg_event.set()
            self.new_msg_event.clear()
        else:
            output['errors'] = form.get_errors()
        return HttpResponse(json.dumps(output))

    def get_messages(self, request):
        """Return up to 100 messages with pk greater than ``pk``, oldest first, as JSON.

        When none exist yet, block until the next message is saved and then
        query again.
        """
        if not request.user.is_authenticated():
            return HttpResponse(status=404)
        pk = int(request.GET.get('pk', 1))
        messages = self._messages_after(pk)
        if not messages:
            self.new_msg_event.wait()
            # The list fetched before waiting is empty by definition; re-query
            # so the woken client actually receives the new message(s).
            messages = self._messages_after(pk)
        return HttpResponse(json.dumps(messages))

    @staticmethod
    def _messages_after(pk):
        """Serialize the newest 100 messages with primary key > pk, oldest first."""
        newest_first = Message.objects.filter(pk__gt=pk).order_by('-created_at')[:100]
        rows = [{'created_at': DateFormat(el.created_at).format('H:i:s'),
                 'username': el.username, 'pk': el.pk,
                 'msg': el.msg} for el in newest_first]
        rows.reverse()
        return rows
예제 #3
0
class TestEvent(Actor):

    '''**Generates a test event at the chosen interval.**

    This module is intended for testing purposes only.

    Events have the following format:

        { "header":{}, "data":"test" }

    Parameters:

        - name (str):               The instance name when initiated.

        - interval (float):         The interval in seconds between each generated event.
                                    A value of 0 disables sleeping, so events are
                                    generated as fast as the outbox accepts them.
                                    default: 1

    Queues:

        - outbox:    Contains the generated events.

    Generation can be paused with enableThrottling() and resumed with
    disableThrottling().
    '''

    def __init__(self, name, interval=1):
        Actor.__init__(self, name, setupbasic=False)
        self.createQueue("outbox")
        self.name = name
        self.interval = interval
        # An interval of exactly 0 selects the no-op sleeper.
        self.sleep = self.doNoSleep if interval == 0 else self.doSleep

        # Set means "not throttled"; go() blocks while it is cleared.
        self.throttle = Event()
        self.throttle.set()

    def preHook(self):
        spawn(self.go)

    def go(self):
        """Emit test events until the context switcher says to stop."""
        keep_going = self.getContextSwitcher(100)
        while keep_going():
            self.throttle.wait()
            event = {"header": {}, "data": "test"}
            try:
                self.queuepool.outbox.put(event)
            except (QueueFull, QueueLocked):
                # Outbox refused the event; wait until it accepts puts again.
                self.queuepool.outbox.waitUntilPutAllowed()
            self.sleep(self.interval)

    def doSleep(self, interval):
        """Pause for ``interval`` seconds."""
        sleep(interval)

    def doNoSleep(self, interval):
        """Do not pause (used when the interval is 0)."""
        return None

    def enableThrottling(self):
        """Pause event generation."""
        self.throttle.clear()

    def disableThrottling(self):
        """Resume event generation."""
        self.throttle.set()
예제 #4
0
class GServer(ProtoBufRPCServer):
    """ZeroMQ ROUTER server that handles each request in a greenlet pool.

    :param host: interface to bind.
    :param port: TCP port to bind.
    :param service: service object stored on ``self.service``.
    :param poolsize: maximum number of requests handled concurrently.
    """
    def __init__(self, host, port, service, poolsize=128):
        self.gpool = Pool(poolsize)
        self.stop_event = Event()
        context = zmq.Context()
        self.port = port
        self.socket = context.socket(zmq.ROUTER)
        self.socket.bind("tcp://%s:%s" % (host, port))
        self.service = service

    def serve_forever(self,):
        """Receive and dispatch requests until shutdown() is called."""
        while not self.stop_event.is_set():
            try:
                msg = self.socket.recv_multipart()
            except zmq.ZMQError:
                # shutdown() closes the socket, which makes recv fail.
                if self.socket.closed:
                    break
                # Any other ZeroMQ error propagates with its original traceback.
                raise
            self.gpool.spawn(self.handle_request, msg)

    def shutdown(self,):
        """Close the socket and stop the serve loop."""
        self.socket.close()
        self.stop_event.set()

    def handle_request(self, msg):
        """Handle one ROUTER message ``[identity, empty delimiter, request]``."""
        assert len(msg) == 3
        (id_, null, request) = msg
        assert null == ''
        response = self.handle(request)
        self.socket.send_multipart([id_, null, response.SerializeToString()])
예제 #5
0
class InputStream(object):
    """
    FCGI_STDIN or FCGI_DATA stream.

    Received data is buffered in a SpooledTemporaryFile: it stays in memory
    until more than ``max_mem`` bytes have been written, then spills to a
    temporary file. Every read method blocks until the EOF mark has been fed.
    """
    def __init__(self, max_mem=1024):
        self._file = SpooledTemporaryFile(max_mem)
        self._eof_received = Event()

    def feed(self, data):
        """Append a received chunk.

        An empty chunk is the EOF mark: the buffer is rewound for reading and
        blocked readers are released.

        :raises IOError: if data is fed after the EOF mark.
        """
        if self._eof_received.is_set():
            raise IOError('Feeding file beyond EOF mark')
        if not data:  # EOF mark
            self._file.seek(0)
            self._eof_received.set()
        else:
            self._file.write(data)

    def __iter__(self):
        self._eof_received.wait()
        return iter(self._file)

    def read(self, size=-1):
        self._eof_received.wait()
        return self._file.read(size)

    def readline(self, size=-1):
        """Read one line (at most ``size`` bytes when size >= 0); blocks until EOF."""
        self._eof_received.wait()
        if size is None or size < 0:
            return self._file.readline()
        return self._file.readline(size)

    def readlines(self, sizehint=0):
        self._eof_received.wait()
        return self._file.readlines(sizehint)

    @property
    def eof_received(self):
        """True once the EOF mark has been fed."""
        return self._eof_received.is_set()
예제 #6
0
    def test_completion_vs_session(self):
        """A 'completion' callback must run while a 'session' callback blocks its thread."""
        h = self._makeOne()

        lst = []
        av = Event()
        bv = Event()
        cv = Event()
        dv = Event()

        def wait_for(ev):
            # Bounded wait: a broken handler fails the test instead of hanging it.
            ev.wait(10)
            assert ev.is_set(), "timed out waiting for callback"

        def addToList():
            av.set()
            bv.wait(10)
            lst.append(True)
            dv.set()

        def anotherAdd():
            lst.append(True)
            cv.set()

        call1 = Callback('session', addToList, ())
        call2 = Callback('completion', anotherAdd, ())

        h.dispatch_callback(call1)
        wait_for(av)
        # Now we know the first is waiting, make sure
        # the second executes while the first has blocked
        # its thread
        h.dispatch_callback(call2)
        wait_for(cv)

        eq_(lst, [True])
        bv.set()
        wait_for(dv)
        eq_(lst, [True, True])
예제 #7
0
class Worker(object):
    '''
    Code run inside a child process.

    A greenlet (communicate) talks to the master process. Each line it reads
    is a task assignment, which is run in its own greenlet. A zero-byte read
    (EOF) is the exit signal. The constructor blocks until that greenlet sets
    ``stop_event`` (which synchronises the two greenlets) and then returns.
    Reporting task progress back to the master is not implemented here.
    '''
    def __init__(self, url):
        self.url = url
        self.stop_event = Event()
        gevent.spawn(self.communicate)
        # Block until communicate() signals that the worker should stop.
        self.stop_event.wait()
        print('worker(%s):will stop' % os.getpid())

    def exec_task(self, task):
        """Execute one task line (currently only prints it)."""
        print('worker(%s):execute task:%s' % (os.getpid(), task.rstrip('\n')))

    def communicate(self):
        """Read task lines from the master until EOF, then signal stop.

        ``self.url`` is passed unchanged to ``socket.connect``, so it must be
        an address that call accepts (for AF_INET, a ``(host, port)`` tuple).
        """
        print('worker(%s):started' % os.getpid())
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client.connect(self.url)
            fp = client.makefile()
            try:
                while True:
                    line = fp.readline()
                    if not line:
                        break
                    # Run each task in its own greenlet so reading never blocks on one.
                    gevent.spawn(self.exec_task, line)
            finally:
                fp.close()
        finally:
            client.close()
            # Also reached when connecting or reading fails, so __init__ cannot hang.
            self.stop_event.set()
예제 #8
0
    def test_heartbeat_with_listeners(self):
        # With a listener whose listen loop is running, the heartbeat should
        # report (True, True, True) and leave no pending heartbeat op.
        listener = Mock(spec=ProcessRPCServer)
        svc = self._make_service()
        proc = IonProcessThread(name=sentinel.name, listeners=[listener], service=svc)

        ready = Event()
        ready.set()
        listener.get_ready_event.return_value = ready

        loop_started = AsyncResult()   # set by the fake listen loop once it runs
        loop_release = Event()         # set to let the fake listen loop return

        def fake_listen(*args, **kwargs):
            loop_started.set(True)
            loop_release.wait()

        listener.listen = fake_listen

        proc.start()
        proc.get_ready_event().wait(timeout=5)
        proc.start_listeners()

        loop_started.wait(timeout=5)   # wait for the listen loop to start

        self.addCleanup(loop_release.set)  # lets the listen loop return on shutdown
        self.addCleanup(proc.stop)

        # now test heartbeat!
        heartbeat = proc.heartbeat()

        self.assertEquals((True, True, True), heartbeat)
        self.assertEquals(0, proc._heartbeat_count)
        self.assertIsNone(proc._heartbeat_op)
예제 #9
0
    def test_heartbeat_listener_dead(self):
        """Heartbeat reports the listener as dead once its listen loop has exited."""
        mocklistener = Mock(spec=ProcessRPCServer)
        svc = self._make_service()
        p = IonProcessThread(name=sentinel.name, listeners=[mocklistener], service=svc)
        readyev = Event()
        readyev.set()
        mocklistener.get_ready_event.return_value = readyev

        def fake_listen(evout, evin):
            evout.set(True)
            evin.wait()

        listenoutev = AsyncResult()
        listeninev = Event()

        # Route listen() through fake_listen so the listen loop really starts
        # (setting listenoutev) and then blocks until listeninev is set below.
        mocklistener.listen = lambda *a, **kw: fake_listen(listenoutev, listeninev)

        p.start()
        p.get_ready_event().wait(timeout=5)
        p.start_listeners()

        listenoutev.wait(timeout=5)         # wait for listen loop to start

        self.addCleanup(listeninev.set)     # makes listen loop fall out on shutdown
        self.addCleanup(p.stop)

        listeninev.set()                    # stop the listen loop
        p.thread_manager.children[1].join(timeout=5)        # wait for listen loop to terminate

        hb = p.heartbeat()

        self.assertEquals((False, True, True), hb)
        self.assertEquals(0, p._heartbeat_count)
        self.assertIsNone(p._heartbeat_op)
예제 #10
0
class Pinger(Greenlet):
    """ Very simple test 'app': connects over AMQP and runs until a message
    whose routing key ends with ``pinger.exit`` arrives.
    """
    def __init__(self, id):
        super(Pinger, self).__init__()
        self.event = Event()
        self.conn = None
        # Populated by on_connect(); stays None if the connection never comes up.
        self.amqp = None
        self.id = id

    def _run(self):
        logger.debug("Pinger starting")
        self.conn = connection.AMQPConnection(self)
        self.conn.connect()
        try:
            # Block until handle_message() receives the exit message.
            self.event.wait()
            logger.debug("Pinger exiting")
        finally:
            # Also runs when the greenlet is killed, so connections are not leaked.
            try:
                if self.amqp is not None:
                    self.amqp.close()
            finally:
                self.conn.close()

    def on_connect(self, connection):
        self.amqp = PingerAMQPManager(connection, self, self.id)

    def handle_message(self, message):
        """Stop the pinger when a message routed to ``...pinger.exit`` arrives."""
        if message.routing_key.endswith('pinger.exit'):
            self.event.set()
예제 #11
0
파일: c2dm.py 프로젝트: soulsharepj/zdzl
class C2DMService(object):
    """Queue-backed sender for C2DM push notifications.

    NOTE(review): the outer ``try:`` in ``_send_loop`` has no matching
    ``except``/``finally`` in this snippet, so the class is not valid Python
    as shown; the handler looks truncated -- confirm against the full source.
    """
    def __init__(self, source, email, password):
        self.source = source
        self.email = email
        self.password = password
        # Notifications waiting to be pushed; consumed by _send_loop().
        self._send_queue = Queue()
        # Set by _send_loop() once it sees the queue empty; never cleared in this snippet.
        self._send_queue_cleared = Event()
        self.log = logging.getLogger('pulsus.service.c2dm')

    def _send_loop(self):
        """Push queued notifications forever.

        A failed push is logged, put back on the queue, and retried after a
        5 second pause.
        """
        # Remember which greenlet runs this loop (presumably so it can be
        # killed on shutdown -- verify against callers).
        self._send_greenlet = gevent.getcurrent()
        try:
            self.log.info("C2DM service started")
            while True:
                notification = self._send_queue.get()
                try:
                    self._do_push(notification)
                except Exception, e:
                    self.log.exception("Error while pushing")
                    self._send_queue.put(notification)
                    gevent.sleep(5.0)
                finally:
                    # Signal "queue drained" the first time the queue is seen empty.
                    if self._send_queue.qsize() < 1 and \
                            not self._send_queue_cleared.is_set():
                        self._send_queue_cleared.set()
예제 #12
0
파일: views.py 프로젝트: go4ble/posty
class Chat(object):
    """Long-polling chat views backed by ``models.Msg``."""

    def __init__(self):
        # TODO: a buffer of recent messages may be wanted at some point.
        # Pulsed (set, then cleared) after each saved message to wake update() callers.
        self.msg_event = Event()

    def index(self, request):
        """Render index.html with an empty message form and all messages."""
        form = models.MsgForm()
        msg_list = models.Msg.objects.all()
        return render(request, 'index.html', {
            'form': form,
            'msg_list': msg_list,
        })

    def send(self, request):
        """Save a POSTed message; respond with JSON ``true`` on success, else ``false``."""
        if request.method == 'POST':
            form = models.MsgForm(request.POST)
            if form.is_valid():
                form.save()
                # tell everyone who's waiting on msg_event that a msg was just
                # posted
                self.msg_event.set()
                self.msg_event.clear()
                return HttpResponse(json.dumps(True), content_type='application/json')
        return HttpResponse(json.dumps(False), content_type='application/json')

    def update(self, request):
        """Block until the next message is posted, then return as XML every
        message whose ``time_stamp`` is at or after the moment this call started.

        NOTE: messages posted between two update() calls from the same client
        are not returned, because the cutoff is taken when each call starts.
        """
        check_time = datetime.now()
        # wait for next msg post
        self.msg_event.wait()
        msg_list = models.Msg.objects.filter(time_stamp__gte=check_time)
        return HttpResponse(serializers.serialize('xml', msg_list), content_type='text/xml')
예제 #13
0
def test_client_shutdown(copper_client):
    """Shutdown waits for in-flight handlers and rejects new streams meanwhile."""
    handler_entered = Event()
    handler_released = Event()

    def handler(stream):
        handler_entered.set()
        handler_released.wait()
        stream.write('Hello, world!')

    with copper_client.publish('test:helloworld', handler):
        with copper_client.subscribe('test:helloworld') as sub:
            with sub.open() as stream:
                handler_entered.wait()
                # Start shutting down. It cannot finish while `handler` is
                # still blocked, so the short Timeout fires; timing out the
                # call does not cancel the shutdown itself.
                with pytest.raises(Timeout):
                    with Timeout(0.005):
                        copper_client.shutdown(unpublish=False)
                # The in-flight handler must still be able to reply.
                handler_released.set()
                assert stream.read() == 'Hello, world!'
            with sub.open() as stream:
                # The service was not unpublished, so new streams fail with
                # ECONNSHUTDOWN instead of reaching the handler.
                with pytest.raises(ConnectionShutdownError):
                    stream.read()
            # No handler is running any more, so shutdown now completes.
            copper_client.shutdown()
예제 #14
0
class _Registration(object):
    """A service registration."""

    def __init__(self, client, form_name, instance_name, data,
                 interval=3):
        self.stopped = Event()
        self.client = client
        self.form_name = form_name
        self.instance_name = instance_name
        self.data = data
        self.interval = interval
        self.gthread = None

    def _loop(self):
        uri = '/%s/%s' % (self.form_name, self.instance_name)
        while not self.stopped.isSet():
            response = self.client._request('PUT', uri,
                                            data=json.dumps(self.data))
            self.stopped.wait(self.interval)

    def start(self):
        self.gthread = gevent.spawn(self._loop)
        return self

    def stop(self, timeout=None):
        self.stopped.set()
        self.gthread.join(timeout)
예제 #15
0
    def test_semaphore(self):
        # With max_concurrent_calls == 1, a second call made while the first
        # is still running must not run yet, and the edge must log
        # "too many concurrent callers".
        edge = APIEdge(MockApp(), self.get_settings())
        api = edge.app.api
        edge.max_concurrent_calls = 1

        first_entered = Event()
        first_may_finish = Event()

        def first_method():
            first_entered.set()
            first_may_finish.wait()

        api.first_method = first_method

        second_entered = Event()

        def second_method():
            second_entered.set()

        api.second_method = second_method

        # Occupy the only slot.
        gevent.spawn(edge.execute, Call("first_method"))
        first_entered.wait()

        # Attempt a second call while the slot is held; yield once so it runs.
        gevent.spawn(edge.execute, Call("second_method"))
        gevent.sleep(0)

        assert_logged("too many concurrent callers")
        assert not second_entered.is_set()

        # Once the first call finishes, the second one eventually runs.
        first_may_finish.set()
        second_entered.wait()
        self.assert_edge_clean(edge)
예제 #16
0
    def running(self, qpid_handle, frequency=3):
        '''Context: serve ``qpid_handle`` with this agent in a background greenlet.

        While the ``with`` body runs, a greenlet repeatedly calls
        ``self(qpid_handle)`` under a ``1.0 / frequency`` second timeout. An
        exception escaping that greenlet is thrown into the hub's parent
        greenlet. On exit the greenlet is asked to stop and is joined.
        '''

        def job(stop):
            """Serve qpid_handle until `stop` is set."""
            from qpid_handle import Timeout, Empty

            while True:
                gevent.sleep()  # yield to other greenlets every iteration
                if stop.is_set():
                    log.debug("joining %r" % self)
                    return
                try:
                    with qpid_handle.timeout(1.0 / frequency):
                        self(qpid_handle)
                except (Timeout, Empty):
                    # Treated as "no messages this round"; poll again.
                    log.debug("no messages %r" % self)

        def exception_handler(failed):
            gevent.get_hub().parent.throw(failed.exception)

        stop_event = Event()
        agent = gevent.spawn(job, stop_event)
        agent.link_exception(exception_handler)

        try:
            yield self
        finally:
            stop_event.set()
            gevent.sleep()
            agent.join()
예제 #17
0
class _FormationCache(object):
    """A cache of instance data for a formation."""

    def __init__(self, client, form_name, factory, interval):
        self.client = client
        self.form_name = form_name
        self.factory = factory
        self.interval = interval
        self._gthread = None
        self._cache = {}
        self._stopped = Event()
        self._running = Event()

    def start(self):
        self._gthread = gevent.spawn(self._loop)
        self._running.wait(timeout=0.1)
        return self

    def stop(self, timeout=None):
        self._stopped.set()
        self._gthread.join(timeout)

    def _update(self):
        self._cache = self.client.query_formation(self.form_name,
                                                  self.factory)

    def _loop(self):
        while not self._stopped.isSet():
            self._update()
            self._running.set()
            self._stopped.wait(self.interval)

    def query(self):
        """Return all instances and their names."""
        return dict(self._cache)
예제 #18
0
파일: client.py 프로젝트: jmakov/market_tia
class StdioPipedWebSocketClient(WebSocketClient):
    """WebSocket client whose traffic is handled by ``common.StdioPipedWebSocketHelper``.

    ``connect_and_wait()`` blocks until the connection is closed.
    """

    def __init__(self, scheme, host, port, path, opts):
        url = "{0}://{1}:{2}{3}".format(scheme, host, port, path)
        WebSocketClient.__init__(self, url)

        self.path = path
        # Set by closed(); connect_and_wait() blocks on it.
        self.shutdown_cond = Event()
        self.opts = opts
        self.iohelper = common.StdioPipedWebSocketHelper(self.shutdown_cond, opts)

    def received_message(self, m):
        # TODO: this is the hook where incoming messages can be retrieved.
        self.iohelper.received_message(self, m)

    def opened(self):
        if self.opts.verbosity >= 1:
            # getpeername() returns (host, port) for IPv4 but a 4-tuple for
            # IPv6; only the first two fields are needed here.
            peername, peerport = self.sock.getpeername()[:2]
            sys.stderr.write("[%s] %d open for path '%s'\n" % (peername, peerport, self.path))
        self.iohelper.opened(self)

    def closed(self, code, reason=None):
        self.shutdown_cond.set()

    def connect_and_wait(self):
        """Connect, then block until the connection is closed."""
        self.connect()
        self.shutdown_cond.wait()
예제 #19
0
class StreamInfoNotifier(object):
    """Lets callers block until a stream's info has been published.

    At most ``MAX_WAITER`` callers may be waiting at the same time; further
    callers are rejected with a 503 ``StreamSwitchError``.
    """

    MAX_WAITER = 1000

    def __init__(self, stream_name):
        self.stream_name = stream_name
        self.stream_info = None
        self.ready_event = Event()
        # Number of callers currently inside wait_stream_info().
        self.cur_waiter_num = 0

    def wait_stream_info(self, timeout):
        """Return the stream info, waiting up to ``timeout`` seconds for it.

        :raises StreamSwitchError: 503 when MAX_WAITER callers are already
            waiting; 408 when the info is not published before the timeout.
        """
        if self.cur_waiter_num >= self.MAX_WAITER:
            raise StreamSwitchError("Too Many Waiter for Stream Info (%s)" %
                                    self.stream_name, 503)
        self.cur_waiter_num += 1
        try:
            ready = self.ready_event.wait(timeout=timeout)
        finally:
            # Release the slot whether we got the info, timed out, or were killed.
            self.cur_waiter_num -= 1

        if not ready:
            # timeout
            raise StreamSwitchError(
                "Stream Info (%s) Not Ready" % self.stream_name, 408)

        return self.stream_info

    def put_stream_info(self, stream_info):
        """Publish the stream info and release all waiters."""
        self.stream_info = stream_info
        self.ready_event.set()
예제 #20
0
파일: rack.py 프로젝트: creotiv/gflib
class ServerRack(object):
    """Start a group of servers together and block until they are stopped.

    :param servers: objects providing ``start()`` and ``stop()``.
    """

    def __init__(self, servers):
        self.servers = servers
        # Set by stop(); serve_forever() blocks on it.
        self.ev = Event()

    def start(self):
        """Start every server in order.

        If one fails to start, the servers already started are stopped and
        the exception is re-raised.
        """
        started = []
        try:
            for server in self.servers[:]:
                server.start()
                started.append(server)
        except:
            # Roll back on any failure, then re-raise it unchanged.
            self.stop(started)
            raise

    def serve_forever(self):
        """Start all servers, then block until stop() is called."""
        self.start()
        self.ev.wait()

    def stop(self, servers=None):
        """Stop ``servers`` (default: all of them), then release serve_forever().

        An error from one server is reported and does not prevent the
        remaining servers from being stopped.
        """
        if servers is None:
            servers = self.servers[:]
        for server in servers:
            try:
                server.stop()
            except:
                if hasattr(server, 'loop'):  # gevent >= 1.0
                    server.loop.handle_error(server.stop, *sys.exc_info())
                else:  # gevent <= 0.13
                    import traceback
                    traceback.print_exc()
        self.ev.set()
예제 #21
0
class BlackBerryPushService(object):
    """Queue-backed sender for BlackBerry Push notifications.

    NOTE(review): the outer ``try:`` in ``_send_loop`` has no matching
    ``except``/``finally`` in this snippet, so the class is not valid Python
    as shown; the handler looks truncated -- confirm against the full source.
    """
    def __init__(self, app_id, password, push_url):
        self.app_id = app_id
        self.password = password
        self.push_url = push_url
        # Notifications waiting to be pushed; consumed by _send_loop().
        self._send_queue = Queue()
        # Set by _send_loop() once it sees the queue empty; never cleared in this snippet.
        self._send_queue_cleared = Event()
        self.log = logging.getLogger('pulsus.service.bbp')

    def _send_loop(self):
        """Push queued notifications forever.

        A failed push is put back on the queue and retried after a 5 second
        pause.
        """
        # Remember which greenlet runs this loop (presumably so it can be
        # killed on shutdown -- verify against callers).
        self._send_greenlet = gevent.getcurrent()
        try:
            self.log.info("BlackBerry Push service started")
            while True:
                notification = self._send_queue.get()
                try:
                    self._do_push(notification)
                except Exception, e:
                    # NOTE(review): the error goes to stdout via print, not self.log.
                    print e
                    self._send_queue.put(notification)
                    gevent.sleep(5.0)
                finally:
                    # Signal "queue drained" the first time the queue is seen empty.
                    if self._send_queue.qsize() < 1 and \
                            not self._send_queue_cleared.is_set():
                        self._send_queue_cleared.set()
예제 #22
0
    def test_greenlet(self):
        # Three producer greenlets feed one consumer through a JoinableQueue.
        # `requests_done` is set once every producer has finished (presumably
        # the consumer's signal to stop -- see _consumer).
        queue = JoinableQueue()
        requests_done = Event()

        producers = [Greenlet(self._producer, queue, service_cls(), 'Terminator')
                     for service_cls in (FirstService, SecondService, ThirdService)]

        requests = Group()
        for producer in producers:
            requests.add(producer)

        log.debug('before spawn')

        consumer = spawn(
            self._consumer,
            done=requests_done,
            queue=queue,
        )
        for request in requests:
            request.start()

        log.debug('after spawn')

        requests.join()
        requests_done.set()

        log.debug('requests are done')

        consumer.join()

        log.debug('consumer is done')
예제 #23
0
class GeventCursor(net.Cursor):
    """Cursor whose blocking reads cooperate with gevent.

    ``_extend`` pulses ``new_response`` (set, then clear) so that greenlets
    blocked in ``_get_next`` loop around and re-check ``self.items``.
    """

    def __init__(self, *args, **kwargs):
        super(GeventCursor, self).__init__(*args, **kwargs)
        self.new_response = Event()

    def __iter__(self):
        return self

    def __next__(self):
        return self._get_next(None)

    def _empty_error(self):
        return GeventCursorEmpty()

    def _extend(self, res):
        super(GeventCursor, self)._extend(res)
        # Wake the greenlets blocked in _get_next, then re-arm the event.
        self.new_response.set()
        self.new_response.clear()

    def _get_next(self, timeout):
        # A timeout of None means wait without a deadline.
        with gevent.Timeout(timeout, RqlTimeoutError()):
            self._maybe_fetch_batch()
            while not self.items:
                if self.error is not None:
                    raise self.error
                self.new_response.wait()
            return self.items.popleft()
예제 #24
0
def handle():
    """Rebuild core_ratequery.

    Clears the table, then queues every (age, sex, province, education, cycle)
    combination for 50 ``worker`` greenlets and waits for them to finish.
    ``None`` is appended to each dimension's ids (presumably meaning "any" for
    that dimension -- confirm in ``worker``).
    """
    connection = create_postgresql_connection()
    try:
        cursor = connection.cursor()
        cursor.execute("BEGIN;")
        cursor.execute("DELETE FROM core_ratequery;")
        cursor.execute("COMMIT;")
        cursor.close()

        queue = JoinableQueue()
        event = Event()

        # list() keeps the concatenation working where .values() returns a view.
        age_ids = list(age_map(connection).values()) + [None]
        sex_ids = list(sex_map(connection).values()) + [None]
        education_ids = list(education_map(connection).values()) + [None]
        province_ids = list(province_map(connection).values()) + [None]

        cursor = connection.cursor()
        cursor.execute("SELECT DISTINCT cycle FROM core_microdata;")
        cycles = [row[0] for row in cursor]
        cursor.close()

        greenlets = [gevent.spawn(worker, queue, event) for _ in range(50)]

        combs = itertools.product(age_ids, sex_ids, province_ids, education_ids, cycles)
        for c in combs:
            queue.put(c)

        # Wait until every queued combination is marked done, then set `event`
        # (presumably the workers' exit signal -- verify in worker) and join them.
        queue.join()
        event.set()
        gevent.joinall(greenlets)
    finally:
        connection.close()
예제 #25
0
class TestTrasforms(IonIntegrationTestCase):
    """Integration tests for TransformBase stats and stream routing."""

    def setUp(self):
        self._start_container()
        # Names of queues / exchange points to delete in tearDown.
        self.queue_cleanup = []
        self.exchange_cleanup = []

    def tearDown(self):
        ex_manager = self.container.ex_manager
        for queue_name in self.queue_cleanup:
            ex_manager.create_xn_queue(queue_name).delete()
        for exchange_name in self.exchange_cleanup:
            ex_manager.create_xp(exchange_name).delete()

    def test_stats(self):
        self.container.spawn_process('test', 'ion.core.process.transform', 'TransformBase', {}, 'test_transform')
        proc = self.container.proc_manager.procs['test_transform']
        proc._stats['hits'] = 100

        self.assertEquals(TransformBase.stats('test_transform'), {'hits': 100})

    def test_stream_transforms(self):
        # A message published on the input route must pass through the
        # transform and arrive unchanged on the output route.
        self.verified = Event()
        input_route = StreamRoute('test_exchange', 'input')
        output_route = StreamRoute('test_exchange', 'output')

        def verify(m, route, stream_id):
            self.assertEquals(route, output_route)
            self.assertEquals(m, 'test')
            self.verified.set()

        # Create I/O processes: publisher -> transform -> subscriber
        pub_proc = TransformBase()
        pub_proc.container = self.container
        publisher = StreamPublisher(process=pub_proc, stream_route=input_route)

        transform_config = {'process': {'queue_name': 'transform_input',
                                        'exchange_point': output_route.exchange_point,
                                        'routing_key': output_route.routing_key}}
        transform_pid = self.container.spawn_process('transform', 'ion.core.process.test.test_transform', 'EmptyDataProcess', transform_config, 'transformpid')
        transform = self.container.proc_manager.procs[transform_pid]

        sub_proc = TransformBase()
        sub_proc.container = self.container
        subscriber = StreamSubscriber(process=sub_proc, exchange_name='subscriber', callback=verify)

        # Bind the transports
        transform.subscriber.xn.bind(input_route.routing_key, publisher.xp)
        subscriber.xn.bind(output_route.routing_key, transform.publisher.xp)
        subscriber.start()

        publisher.publish('test')

        self.assertTrue(self.verified.wait(4))
예제 #26
0
파일: ws.py 프로젝트: OpenSight/IVR
class WSClientTransport(WebSocketClient):
    """WebSocket client transport that hands packets to an application object.

    ``APP_FACTORY`` must be set to a callable taking this transport; connect()
    calls it to create the app that receives packets and the close notice.
    """
    APP_FACTORY = None

    def __init__(self, url):
        self._close_event = Event()
        WebSocketClient.__init__(self, url)
        self._app = None
        self._lock = RLock()
        # Patch socket.sendall to hold a lock, preventing data from being sent
        # from multiple greenlets concurrently.
        _sendall = self.sock.sendall

        def sendall(data):
            with self._lock:
                _sendall(data)
        self.sock.sendall = sendall

    def connect(self):
        super(WSClientTransport, self).connect()
        self._app = self.APP_FACTORY(self)
        log.info("Connected to websocket server {0}".format(self.url))

    def closed(self, code, reason=None):
        """Notify the app (at most once, since _app is cleared) and release wait_close()."""
        app, self._app = self._app, None
        if app:
            app.on_close()
        self._close_event.set()

    def ponged(self, pong):
        # Pong frames are ignored.
        pass

    def received_message(self, message):
        log.debug("Received message {0}".format(message))
        if self._app:
            self._app.on_received_packet(STRING(message))
        else:
            log.warning('Websocket client app already closed')

    def send_packet(self, data):
        log.debug("Sending message {0}".format(data))
        self.send(data)

    def force_shutdown(self):
        # Called by the upper layer. The app is detached first, so closed()
        # will not call back into it.
        self._app = None
        self.close()
        self._close_event.set()
        log.info('Websocket client closed')

    def wait_close(self):
        """Block until closed() or force_shutdown() has run."""
        self._close_event.wait()

    def app(self):
        return self._app
예제 #27
0
 def test(self):
     # An Event that is set and immediately cleared must still wake every
     # greenlet already blocked in wait(); otherwise the joins below hang.
     event = Event()
     waiters = [gevent.spawn(event.wait) for _ in range(self.N)]
     # Yield briefly so the spawned greenlets get to block on the event.
     gevent.sleep(0.001)
     event.set()
     event.clear()
     for waiter in waiters:
         waiter.join()
예제 #28
0
 def __init__(self, player_index):
     self.player_index = player_index
     # NOTE(review): hard-coded account name/password -- presumably a test
     # fixture; confirm before reusing this elsewhere.
     self.account = Account.authenticate("Proton", "123")
     self.observers = list()
     # NOTE(review): what maxsize 0 means depends on which Queue is imported
     # (unbounded for the stdlib Queue; older gevent treated 0 as a channel).
     self._channel = Queue(0)
     # Created already set, so wait() on it returns at once until it is cleared.
     self.gdevent = Event()
     self.gdevent.set()
예제 #29
0
class EventPersister(StandaloneProcess):
    """Persists every received event to the container's event repository.

    Events from the subscriber are buffered in ``event_queue`` and written in
    batches every ``persist_interval`` seconds by a background greenlet.
    on_quit() writes whatever is still buffered after that greenlet stops.
    """

    def on_init(self):
        # Seconds between batch persists
        self.persist_interval = 1.0

        # Holds received events FIFO
        self.event_queue = Queue()

        # Batch currently being persisted; None when no datastore write is in flight
        self.events_to_persist = None

        # bookkeeping for the timer greenlet
        self._persist_greenlet = None
        self._terminate_persist = Event()  # when set, the timer greenlet exits

        # The event subscriber
        self.event_sub = None

    def on_start(self):
        # Persister greenlet
        self._persist_greenlet = spawn(self._trigger_func, self.persist_interval)
        log.debug('Publisher Greenlet started in "%s"' % self.__class__.__name__)

        # Event subscription
        self.event_sub = EventSubscriber(pattern=EventSubscriber.ALL_EVENTS, callback=self._on_event)
        self.event_sub.start()

    def on_quit(self):
        # Stop the subscriber first so no new events are queued
        self.event_sub.stop()

        # tell the trigger greenlet we're done
        self._terminate_persist.set()

        # wait on the greenlet to finish cleanly
        self._persist_greenlet.join(timeout=10)

        # The trigger loop exits as soon as _terminate_persist is set, without
        # a final pass, so persist anything received since its last batch.
        try:
            self._persist_events(self._drain_queue())
        except Exception:
            log.exception("Failed to persist remaining events on quit")

    def _on_event(self, event, *args, **kwargs):
        self.event_queue.put(event)

    def _drain_queue(self):
        """Remove and return every event currently queued, in FIFO order."""
        return [self.event_queue.get() for _ in xrange(self.event_queue.qsize())]

    def _trigger_func(self, persist_interval):
        log.debug('Starting event persister thread with persist_interval=%s', persist_interval)

        # Event.wait returns False on timeout (and True once set in on_quit), so it
        # serves both as the interval timer and as the exit condition.
        while not self._terminate_persist.wait(timeout=persist_interval):
            try:
                self.events_to_persist = self._drain_queue()

                self._persist_events(self.events_to_persist)
                self.events_to_persist = None
            except Exception:
                log.exception("Failed to persist received events")
                return False

    def _persist_events(self, event_list):
        """Write ``event_list`` to the event repository; no-op when it is empty."""
        if event_list:
            bootstrap.container_instance.event_repository.put_events(event_list)
예제 #30
0
class AvailManager(object):
    """Runs a room-availability simulation and serves long-poll update queries.

    ``run_sim`` feeds the currently available rooms, in random order, through
    ``updateavail``. For each room it appends a ``RoomUpdate`` to
    ``self.updates`` and pulses ``self.event``, which wakes any greenlet
    blocked in ``check_avail``.
    """
    SIM_RUNNING = False
    sim_thread = None

    def __init__(self):
        # RoomUpdate objects in the order run_sim produced them (oldest first).
        self.updates = []
        # Pulsed (set then immediately cleared) after every new update to wake waiters.
        self.event = Event()

    def run_sim(self, delay, size):
        """Push every currently-available room through ``updateavail`` in random order.

        :param delay: seconds to sleep between batches.
        :param size: a sleep happens whenever the update index is a multiple
            of ``size`` (including index 0).
        """
        roomids = Room.objects.filter(avail=True).values_list('id', flat=True)
        # Shuffle once after collecting all ids (previously reshuffled after
        # every append, which is O(n^2) for the same uniform permutation).
        ids = list(roomids)
        random.shuffle(ids)
        for i, room_id in enumerate(ids):
            roomg = Room.objects.filter(id=room_id)
            updateavail(Room.objects.filter(id=room_id))
            # roomg is a lazy QuerySet, so it is evaluated here, after
            # updateavail ran — presumably so RoomUpdate sees the updated row.
            room = roomg[0]
            self.updates.append(RoomUpdate(room))
            # Pulse: wake current waiters, then re-arm for the next update.
            self.event.set()
            self.event.clear()
            if i % size == 0:
                sleep(delay)
        self.SIM_RUNNING = False

    def start_sim(self, delay, size=1):
        """(Re)start the simulation greenlet after resetting all rooms to available."""
        if self.SIM_RUNNING:
            kill(self.sim_thread)
        self.updates = []
        Room.objects.all().update(avail=True)
        self.sim_thread = spawn(self.run_sim, delay=delay, size=size)
        self.SIM_RUNNING = True
        return HttpResponse('Started Simulation with delay %d' % delay)

    def stop_sim(self):
        """Kill the running simulation greenlet, if any."""
        if not self.SIM_RUNNING:
            return HttpResponse('No current simulation')
        kill(self.sim_thread)
        self.SIM_RUNNING = False
        return HttpResponse('Stopped simulation')

    def check_avail(self, timestamp):
        """Return ids of rooms updated at or after ``timestamp`` (newest first).

        Blocks until the next update only when there is no update at or after
        ``timestamp`` yet (long-poll semantics).

        :returns: ``{'timestamp': <now as int seconds>, 'rooms': [room_id, ...]}``
        """
        # Compare against the NEWEST update; comparing against the oldest made
        # callers wait even when newer updates were already available.
        if not self.updates or timestamp > self.updates[-1].timestamp:
            self.event.wait()
        room_ids = []
        # Walk newest -> oldest and stop at the first update older than timestamp.
        for update in reversed(self.updates):
            if timestamp <= update.timestamp:
                room_ids.append(update.room_id)
            else:
                break

        return {'timestamp': int(time.time()),
                'rooms': room_ids}
예제 #31
0
class GMatrixClient(MatrixClient):
    """Gevent-compliant MatrixClient subclass.

    Syncing and message handling run in two separate greenlets:
    ``sync_worker`` repeatedly calls ``_sync``, which puts each /sync response
    on ``response_queue``. ``message_worker`` (``_handle_message``) drains that
    queue and dispatches the responses via ``_handle_responses``.
    """

    sync_worker: Optional[Greenlet] = None
    message_worker: Optional[Greenlet] = None
    last_sync: float = float("inf")

    def __init__(
        self,
        handle_messages_callback: Callable[[MatrixSyncMessages], bool],
        handle_member_join_callback: Callable[[Room], None],
        base_url: str,
        token: Optional[str] = None,
        user_id: Optional[str] = None,
        valid_cert_check: bool = True,
        sync_filter_limit: int = 20,
        cache_level: CACHE = CACHE.ALL,
        http_pool_maxsize: int = 10,
        http_retry_timeout: int = 60,
        http_retry_delay: Callable[[], Iterable[float]] = lambda: repeat(1),
        environment: Environment = Environment.PRODUCTION,
        user_agent: Optional[str] = None,
    ) -> None:
        """Create the client and its GMatrixHttpApi.

        Args:
            handle_messages_callback: called with all (room, messages) pairs
                collected from a batch of sync responses.
            handle_member_join_callback: called with a room whose member count
                changed while processing a sync response.
            http_pool_maxsize, http_retry_timeout, http_retry_delay, user_agent:
                forwarded to GMatrixHttpApi.
            environment: in DEVELOPMENT, ``_sync`` raises when syncs are too far apart.
        """

        self.token: Optional[str] = None
        self.environment = environment
        self.handle_messages_callback = handle_messages_callback
        self._handle_member_join_callback = handle_member_join_callback
        self.response_queue: NotifyingQueue[Tuple[UUID, JSONResponse, datetime]] = NotifyingQueue()
        self.stop_event = Event()

        super().__init__(
            base_url, token, user_id, valid_cert_check, sync_filter_limit, cache_level
        )
        self.api = GMatrixHttpApi(
            base_url,
            token,
            pool_maxsize=http_pool_maxsize,
            retry_timeout=http_retry_timeout,
            retry_delay=http_retry_delay,
            long_paths=("/sync",),
            user_agent=user_agent,
        )
        self.api.validate_certificate(valid_cert_check)

        # Monotonically increasing id to ensure that presence updates are processed in order.
        self._presence_update_ids: Iterator[int] = itertools.count()
        self._worker_pool = gevent.pool.Pool(size=20)
        # Gets incremented every time a sync loop is completed. This is useful since the sync token
        # can remain constant over multiple loops (if no events occur).
        self.sync_progress = SyncProgress(self.response_queue)
        self._sync_filter_id: Optional[int] = None

    @property
    def synced(self) -> Event:
        return self.sync_progress.synced_event

    @property
    def processed(self) -> Event:
        return self.sync_progress.processed_event

    @property
    def sync_iteration(self) -> int:
        return self.sync_progress.sync_iteration

    def create_sync_filter(
        self,
        rooms: Optional[Iterable[Room]] = None,
        not_rooms: Optional[Iterable[Room]] = None,
        limit: Optional[int] = None,
    ) -> Optional[int]:
        """Create a matrix sync filter

        A whitelist and blacklist of rooms can be supplied optionally. If
        no whitelist is given, all rooms are whitelisted. The blacklist is
        applied on top of the whitelist.

        Ref. https://matrix.org/docs/spec/client_server/r0.6.0#api-endpoints

        Args:
            rooms: whitelist of rooms, if not given all rooms are whitelisted
            not_rooms: blacklist of rooms, applied after the whitelist
            limit: maximum number of messages to return

        Returns:
            The ``filter_id`` returned by the server, or None when none of
            ``rooms``, ``not_rooms`` and ``limit`` was given (no filter created).

        Raises:
            TransportError: the server rejected the filter creation.
        """
        if not_rooms is None and rooms is None and limit is None:
            return None

        broadcast_room_filter: Dict[str, Dict] = {
            # Get all presence updates
            "presence": {"types": ["m.presence"]},
            # filter account data
            "account_data": {"not_types": ["*"]},
            # Ignore "message receipts" from all rooms
            "room": {"ephemeral": {"not_types": ["m.receipt"]}},
        }
        if not_rooms:
            negative_rooms = [room.room_id for room in not_rooms]
            broadcast_room_filter["room"].update(
                {
                    # Filter out all unwanted rooms
                    "not_rooms": negative_rooms
                }
            )
        if rooms:
            positive_rooms = [room.room_id for room in rooms]
            broadcast_room_filter["room"].update(
                {
                    # Set all wanted rooms
                    "rooms": positive_rooms
                }
            )

        limit_filter: Dict[str, Any] = {}
        if limit is not None:
            limit_filter = {"room": {"timeline": {"limit": limit}}}

        final_filter = broadcast_room_filter
        merge_dict(final_filter, limit_filter)

        try:
            # 0 is a valid filter ID
            filter_response = self.api.create_filter(self.user_id, final_filter)
            filter_id = filter_response.get("filter_id")
            log.debug("Sync filter created", filter_id=filter_id, filter=final_filter)

        except MatrixRequestError as ex:
            raise TransportError(
                f"Failed to create filter: {final_filter} for user {self.user_id}"
            ) from ex

        return filter_id

    def listen_forever(
        self,
        timeout_ms: int,
        latency_ms: int,
        exception_handler: Optional[Callable[[Exception], None]] = None,
        bad_sync_timeout: int = 5,
    ) -> None:
        """
        Keep listening for events until ``stop_event`` is set.

        Args:
            timeout_ms: How long to poll the Home Server for before retrying.
            latency_ms: Extra allowance added to ``timeout_ms`` when ``_sync``
                checks (in DEVELOPMENT only) that syncs are not too far apart.
            exception_handler: Optional exception handler function which can
                be used to handle exceptions in the caller thread.
            bad_sync_timeout: Base time to wait after an error before retrying.
                Doubled after each consecutive error, capped at
                ``self.bad_sync_timeout_limit``.
        """
        _bad_sync_timeout = bad_sync_timeout

        while not self.stop_event.is_set():
            try:
                # may be killed and raise exception from message_worker
                self._sync(timeout_ms, latency_ms)
                _bad_sync_timeout = bad_sync_timeout
            except MatrixRequestError as e:
                log.warning(
                    "A MatrixRequestError occurred during sync.",
                    node=node_address_from_userid(self.user_id),
                    user_id=self.user_id,
                )
                if e.code >= 500:
                    log.warning(
                        "Problem occurred serverside. Waiting",
                        node=node_address_from_userid(self.user_id),
                        user_id=self.user_id,
                        wait_for=_bad_sync_timeout,
                    )
                    gevent.sleep(_bad_sync_timeout)
                    _bad_sync_timeout = min(_bad_sync_timeout * 2, self.bad_sync_timeout_limit)
                else:
                    raise
            except MatrixHttpLibError:
                log.exception(
                    "A MatrixHttpLibError occurred during sync.",
                    node=node_address_from_userid(self.user_id),
                    user_id=self.user_id,
                )
                if not self.stop_event.is_set():
                    gevent.sleep(_bad_sync_timeout)
                    _bad_sync_timeout = min(_bad_sync_timeout * 2, self.bad_sync_timeout_limit)
            except Exception as e:
                log.exception(
                    "Exception thrown during sync",
                    node=node_address_from_userid(self.user_id),
                    user_id=self.user_id,
                )
                if exception_handler is not None:
                    exception_handler(e)
                else:
                    raise

    def start_listener_thread(
        self, timeout_ms: int, latency_ms: int, exception_handler: Optional[Callable] = None
    ) -> None:
        """
        Start a listener greenlet to listen for events in the background.

        Args:
            timeout_ms: How long to poll the Home Server for before retrying.
            latency_ms: Extra sync-interval allowance, see ``listen_forever``.
            exception_handler: Optional exception handler function which can
                be used to handle exceptions in the caller thread.
        """
        assert self.sync_worker is None, "Already running"
        # Needs to be reset, otherwise we might run into problems when restarting
        self.last_sync = float("inf")

        self.sync_worker = gevent.spawn(
            self.listen_forever, timeout_ms, latency_ms, exception_handler
        )
        self.sync_worker.name = f"GMatrixClient.sync_worker user_id:{self.user_id}"
        self.message_worker = gevent.spawn(
            self._handle_message, self.response_queue, self.stop_event
        )
        self.message_worker.name = f"GMatrixClient.message_worker user_id:{self.user_id}"
        # Capture this run's sync worker: self.sync_worker is reset to None by
        # stop_listener_thread and replaced on restart, so reading it lazily in
        # the callback could hit None or kill an unrelated, newer worker.
        sync_worker = self.sync_worker
        self.message_worker.link_exception(lambda g: sync_worker.kill(g.exception))

        # FIXME: This is just a temporary hack, this adds a race condition of the user pressing
        #     Ctrl-C before this is run, and Raiden newer shutting down.
        self.stop_event.clear()

    def stop_listener_thread(self) -> None:
        """Kill the sync greenlet, then join it and the message greenlet.

        Raises:
            RuntimeError: a greenlet did not exit within SHUTDOWN_TIMEOUT.
        """
        # when stopping, `kill` will cause the `self.api.sync` call in _sync
        # to raise a connection error. This flag will ensure it exits gracefully then
        self.stop_event.set()

        if self.sync_worker:
            self.sync_worker.kill()
            log.debug(
                "Waiting on sync greenlet",
                node=node_address_from_userid(self.user_id),
                user_id=self.user_id,
            )
            exited = gevent.joinall({self.sync_worker}, timeout=SHUTDOWN_TIMEOUT, raise_error=True)
            if not exited:
                raise RuntimeError("Timeout waiting on sync greenlet during transport shutdown.")
            self.sync_worker.get()

        if self.message_worker is not None:
            log.debug(
                "Waiting on handle greenlet",
                node=node_address_from_userid(self.user_id),
                current_user=self.user_id,
            )
            exited = gevent.joinall(
                {self.message_worker}, timeout=SHUTDOWN_TIMEOUT, raise_error=True
            )
            if not exited:
                raise RuntimeError("Timeout waiting on handle greenlet during transport shutdown.")
            self.message_worker.get()

        log.debug(
            "Listener greenlet exited",
            node=node_address_from_userid(self.user_id),
            user_id=self.user_id,
        )
        self.sync_worker = None
        self.message_worker = None

    def stop(self) -> None:
        """Stop the listener greenlets, reset sync state and wait for pooled workers."""
        self.stop_listener_thread()
        self.sync_token = None
        self.rooms: Dict[str, Room] = {}
        self._worker_pool.join(raise_error=True)

    def logout(self) -> None:
        super().logout()
        self.api.session.close()

    def search_user_directory(self, term: str) -> List[User]:
        """
        Search user directory for a given term, returning a list of users
        Args:
            term: term to be searched for
        Returns:
            user_list: list of users returned by server-side search; empty on
            a server-side (5xx) error or a malformed response
        Raises:
            MatrixRequestError: for non-5xx request errors
        """
        try:
            response = self.api._send("POST", "/user_directory/search", {"search_term": term})
        except MatrixRequestError as ex:
            if ex.code >= 500:
                log.error(
                    "Ignoring Matrix error in `search_user_directory`",
                    exc_info=ex,
                    term=term,
                )
                return list()
            else:
                # Bare raise keeps the original traceback untouched.
                raise
        try:
            return [
                User(self.api, _user["user_id"], _user["display_name"])
                for _user in response["results"]
            ]
        except KeyError:
            return list()

    def set_presence_state(self, state: str) -> Dict:
        return self.api._send(
            "PUT", f"/presence/{quote(self.user_id)}/status", {"presence": state}
        )

    def _mkroom(self, room_id: str) -> Room:
        """ Uses a geventified Room subclass """
        if room_id not in self.rooms:
            self.rooms[room_id] = Room(self, room_id)
        room = self.rooms[room_id]
        if not room.canonical_alias:
            room.update_local_alias()
        return room

    def get_user_presence(self, user_id: str) -> Optional[str]:
        return self.api.get_presence(user_id).get("presence")

    def create_room(
        self,
        alias: Optional[str] = None,
        is_public: bool = False,
        invitees: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> MatrixRoom:
        """Create a new room on the homeserver.

        Args:
            alias (str): The canonical_alias of the room.
            is_public (bool):  The public/private visibility of the room.
            invitees (List[str]): User ids to invite into the room.

        Returns:
            Room

        Raises:
            MatrixRequestError
        """
        response = self.api.create_room(alias, is_public, invitees, **kwargs)
        return self._mkroom(response["room_id"])

    def blocking_sync(self, timeout_ms: int, latency_ms: int) -> None:
        """Perform a /sync and process the response synchronously."""
        self._sync(timeout_ms=timeout_ms, latency_ms=latency_ms)

        pending_queue = []
        while len(self.response_queue) > 0:
            _, response, _ = self.response_queue.get()
            pending_queue.append(response)

        assert all(pending_queue), "Sync returned, None and empty are invalid values."

        self._handle_responses(pending_queue)

    def _sync(self, timeout_ms: int, latency_ms: int) -> None:
        """Reimplements MatrixClient._sync.

        Performs one /sync call and, for a non-empty response, puts
        ``(token, response, received_at)`` on ``response_queue`` before
        advancing ``sync_token``.

        Raises:
            MatrixSyncMaxTimeoutReached: in DEVELOPMENT, when the time since
                the previous sync started is at least
                ``(timeout_ms + latency_ms) // 1000`` seconds.
        """
        log.debug(
            "Sync called",
            node=node_address_from_userid(self.user_id),
            user_id=self.user_id,
            sync_iteration=self.sync_iteration,
            sync_filter_id=self._sync_filter_id,
            last_sync_time=self.last_sync,
        )

        time_before_sync = time.monotonic()
        time_since_last_sync_in_seconds = time_before_sync - self.last_sync

        # If it takes longer than `timeout_ms + latency_ms` to call `_sync`
        # again, we throw an exception.  The exception is only thrown when in
        # development mode.
        timeout_in_seconds = (timeout_ms + latency_ms) // 1_000
        timeout_reached = (
            time_since_last_sync_in_seconds >= timeout_in_seconds
            and self.environment == Environment.DEVELOPMENT
        )
        # The second sync is the first full sync and can be slow. This is
        # acceptable, we only want to know if we fail to sync quickly
        # afterwards.
        # As the runtime is evaluated in the subsequent run, we only run this
        # after the second iteration is finished.
        if timeout_reached:
            if IDLE:
                IDLE.log()

            raise MatrixSyncMaxTimeoutReached(
                f"Time between syncs exceeded timeout:  "
                f"{time_since_last_sync_in_seconds}s > {timeout_in_seconds}s. {IDLE}"
            )

        log.debug(
            "Calling api.sync",
            node=node_address_from_userid(self.user_id),
            user_id=self.user_id,
            sync_iteration=self.sync_iteration,
            time_since_last_sync_in_seconds=time_since_last_sync_in_seconds,
        )
        self.last_sync = time_before_sync
        response = self.api.sync(
            since=self.sync_token, timeout_ms=timeout_ms, filter=self._sync_filter_id
        )
        time_after_sync = time.monotonic()

        log.debug(
            "api.sync returned",
            node=node_address_from_userid(self.user_id),
            user_id=self.user_id,
            sync_iteration=self.sync_iteration,
            time_after_sync=time_after_sync,
            time_taken=time_after_sync - time_before_sync,
        )

        if response:
            token = uuid4()

            log.debug(
                "Sync returned",
                node=node_address_from_userid(self.user_id),
                token=token,
                elapsed=time_after_sync - time_before_sync,
                current_user=self.user_id,
                presence_events_qty=len(response["presence"]["events"]),
                to_device_events_qty=len(response["to_device"]["events"]),
                rooms_invites_qty=len(response["rooms"]["invite"]),
                rooms_leaves_qty=len(response["rooms"]["leave"]),
                rooms_joined_member_count=sum(
                    room["summary"].get("m.joined_member_count", 0)
                    for room in response["rooms"]["join"].values()
                ),
                rooms_invited_member_count=sum(
                    room["summary"].get("m.invited_member_count", 0)
                    for room in response["rooms"]["join"].values()
                ),
                # NOTE(review): room["state"] is a mapping (it is indexed with
                # ["events"] below), so this counts its keys, not state events.
                rooms_join_state_qty=sum(
                    len(room["state"]) for room in response["rooms"]["join"].values()
                ),
                rooms_join_timeline_events_qty=sum(
                    len(room["timeline"]["events"]) for room in response["rooms"]["join"].values()
                ),
                rooms_join_state_events_qty=sum(
                    len(room["state"]["events"]) for room in response["rooms"]["join"].values()
                ),
                rooms_join_ephemeral_events_qty=sum(
                    len(room["ephemeral"]["events"]) for room in response["rooms"]["join"].values()
                ),
                rooms_join_account_data_events_qty=sum(
                    len(room["account_data"]["events"])
                    for room in response["rooms"]["join"].values()
                ),
            )

            # Updating the sync token should only be done after the response is
            # saved in the queue, otherwise the data can be lost in a stop/start.
            self.response_queue.put((token, response, datetime.now()))
            self.sync_token = response["next_batch"]
            self.sync_progress.set_synced(token)

    def _handle_message(
        self,
        response_queue: NotifyingQueue[Tuple[UUID, JSONResponse, datetime]],
        stop_event: Event,
    ) -> None:
        """Worker to process network messages from the asynchronous transport.

        Note that this worker will process the messages in the order of
        delivery. However, the underlying protocol may not guarantee that
        messages are delivered in-order in which they were sent. The transport
        layer has to implement retries to guarantee that a message is
        eventually processed. This introduces a cost in terms of latency.
        """
        while True:
            gevent.joinall({response_queue, stop_event}, count=1, raise_error=True)

            # Iterating over the Queue and adding to a separated list to
            # implement delivery at-least-once semantics. At-most-once would
            # also be acceptable because of message retries, however it has the
            # potential of introducing latency.
            #
            # The Queue's iterator cannot be used because it defaults to `get`.
            currently_queued_response_tokens = list()
            currently_queued_responses = list()
            for token, response, received_at in response_queue.queue.queue:
                assert response is not None, "None is not a valid value for a Matrix response."

                log.debug(
                    "Handling Matrix response",
                    token=token,
                    node=node_address_from_userid(self.user_id),
                    current_size=len(response_queue),
                    processing_lag=datetime.now() - received_at,
                )
                currently_queued_response_tokens.append(token)
                currently_queued_responses.append(response)

            if stop_event.is_set():
                log.debug(
                    "Handling worker exiting, stop is set",
                    node=node_address_from_userid(self.user_id),
                )
                return
            time_before_processing = time.monotonic()
            self._handle_responses(currently_queued_responses)
            time_after_processing = time.monotonic()
            log.debug(
                "Processed queued Matrix responses",
                node=node_address_from_userid(self.user_id),
                elapsed=time_after_processing - time_before_processing,
            )

            # Pop the processed messages, this relies on the fact the queue is
            # ordered to pop the correct messages. If the process is killed
            # right before this call, on the next transport start the same
            # message will be processed again, that is why this is
            # at-least-once semantics.
            for _ in currently_queued_responses:
                response_queue.get(block=False)

            self.sync_progress.set_processed(currently_queued_response_tokens)

    def _handle_responses(self, currently_queued_responses: List[JSONResponse]) -> None:
        """Dispatch a batch of sync responses to the registered listeners.

        Messages from to_device events and joined rooms' timelines are
        collected across the whole batch and passed in one call to
        ``handle_messages_callback`` (only if any were collected).
        """

        all_messages: MatrixSyncMessages = []

        for response in currently_queued_responses:
            for presence_update in response["presence"]["events"]:
                for callback in list(self.presence_listeners.values()):
                    callback(presence_update, next(self._presence_update_ids))

            for to_device_message in response["to_device"]["events"]:
                for listener in self.listeners[:]:
                    if listener["event_type"] == "to_device":
                        listener["callback"](to_device_message)

            # Add toDevice messages to message queue
            if response["to_device"]["events"]:
                all_messages.append(
                    (
                        None,
                        response["to_device"]["events"],
                    )
                )

            for room_id, invite_room in response["rooms"]["invite"].items():
                for listener in self.invite_listeners[:]:
                    listener(room_id, invite_room["invite_state"])

            for room_id, left_room in response["rooms"]["leave"].items():
                for listener in self.left_listeners[:]:
                    listener(room_id, left_room)
                if room_id in self.rooms:
                    del self.rooms[room_id]

            for room_id, sync_room in response["rooms"]["join"].items():
                if room_id not in self.rooms:
                    self._mkroom(room_id)

                room = self.rooms[room_id]
                room.prev_batch = sync_room["timeline"]["prev_batch"]
                room_members_count = len(room._members)

                for event in sync_room["state"]["events"]:
                    event["room_id"] = room_id
                    room._process_state_event(event)
                for event in sync_room["timeline"]["events"]:
                    event["room_id"] = room_id
                    room._put_event(event)

                # number of members changed. Verify validity of room
                if room_members_count != len(room._members):
                    self._handle_member_join_callback(room)
                all_messages.append(
                    (
                        room,
                        [
                            message
                            for message in sync_room["timeline"]["events"]
                            if message["type"] == "m.room.message"
                        ],
                    )
                )

                for event in sync_room["ephemeral"]["events"]:
                    event["room_id"] = room_id
                    room._put_ephemeral_event(event)

                    # Iterate a copy, like the other listener lists, so a
                    # callback that (un)registers listeners cannot break the loop.
                    for listener in self.ephemeral_listeners[:]:
                        should_call = (
                            listener["event_type"] is None
                            or listener["event_type"] == event["type"]
                        )
                        if should_call:
                            listener["callback"](event)

        if len(all_messages) > 0:
            self.handle_messages_callback(all_messages)

    def set_access_token(self, user_id: str, token: Optional[str]) -> None:
        self.user_id = user_id
        self.token = self.api.token = token

    def set_sync_filter_id(self, sync_filter_id: Optional[int]) -> Optional[int]:
        """ Sets the sync filter to the given id and returns previous filters id """
        prev_id = self._sync_filter_id
        self._sync_filter_id = sync_filter_id
        return prev_id
예제 #32
0
class TestChannelInt(IonIntegrationTestCase):
    """Integration tests for channel queue/consumer behaviour against an AMQP broker."""

    def setUp(self):
        self.patch_cfg(
            'pyon.ion.exchange.CFG', {
                'container': {
                    'messaging': {
                        'server': {
                            'primary': 'amqp',
                            'priviledged': None
                        }
                    }
                }
            })
        self._start_container()

    #@skip('Not working consistently on buildbot')
    def test_consume_one_message_at_a_time(self):
        # end to end test for CIDEVCOI-547 requirements
        #    - Process P1 is producing one message every 5 seconds
        #    - Process P2 is producing one other message every 3 seconds
        #    - Process S creates an auto-delete=False queue without a consumer and without a binding
        #    - Process S binds this queue through a pyon.net or container API call to the topic of process P1
        #    - Process S waits a bit
        #    - Process S checks the number of messages in the queue
        #    - Process S creates a consumer, takes one message off the queue (non-blocking) and destroys the consumer
        #    - Process S waits a bit (let messages accumulate)
        #    - Process S creates a consumer, takes a message off and repeats it until no messages are left (without ever blocking) and destroys the consumer
        #    - Process S waits a bit (let messages accumulate)
        #    - Process S creates a consumer, takes a message off and repeats it until no messages are left (without ever blocking). Then requeues the last message and destroys the consumer
        #    - Process S creates a consumer, takes one message off the queue (non-blocking) and destroys the consumer.
        #    - Process S sends prior message to its queue (note: may be tricky without a subscription to yourself)
        #    - Process S changes the binding of queue to P1 and P2
        #    - Process S removes all bindings of queue
        #    - Process S deletes the queue
        #    - Process S exits without any residual resources in the broker
        #    - Process P1 and P2 get terminated without any residual resources in the broker
        #
        #    * Show this works with the ACK or no-ACK mode
        #    * Do the above with semi-abstracted calls (some nicer boilerplate)

        def every_five():
            p = self.container.node.channel(PublisherChannel)
            p._send_name = NameTrio(bootstrap.get_sys_name(), 'routed.5')
            counter = 0

            # wait() returns False on timeout, so this publishes every 5s until publish_five is set
            while not self.publish_five.wait(timeout=5):
                p.send('5,' + str(counter))
                counter += 1

        def every_three():
            p = self.container.node.channel(PublisherChannel)
            p._send_name = NameTrio(bootstrap.get_sys_name(), 'routed.3')
            counter = 0

            while not self.publish_three.wait(timeout=3):
                p.send('3,' + str(counter))
                counter += 1

        self.publish_five = Event()
        self.publish_three = Event()
        self.five_events = Queue()
        self.three_events = Queue()

        gl_every_five = spawn(every_five)
        gl_every_three = spawn(every_three)

        def listen(lch):
            """
            The purpose of this listen method is to trigger waits in code below.
            By setting up a listener that subscribes to both 3 and 5, and putting received
            messages into the appropriate gevent-queues client side, we can assume that
            the channel we're actually testing with get_stats etc has had the message delivered
            too.
            """
            lch._queue_auto_delete = False
            lch.setup_listener(
                NameTrio(bootstrap.get_sys_name(), 'alternate_listener'),
                'routed.3')
            lch._bind('routed.5')
            lch.start_consume()

            while True:
                try:
                    newchan = lch.accept()
                    m, h, d = newchan.recv()
                    count = m.rsplit(',', 1)[-1]
                    if m.startswith('5,'):
                        self.five_events.put(int(count))
                        newchan.ack(d)
                    elif m.startswith('3,'):
                        self.three_events.put(int(count))
                        newchan.ack(d)
                    else:
                        raise StandardError("unknown message: %s" % m)

                except ChannelClosedError:
                    break

        lch = self.container.node.channel(SubscriberChannel)
        gl_listen = spawn(listen, lch)

        def do_cleanups(gl_e5, gl_e3, gl_l, lch):
            # Stop both publishers, then tear down the listener channel and its greenlet.
            self.publish_five.set()
            self.publish_three.set()
            gl_e5.join(timeout=5)
            gl_e3.join(timeout=5)

            lch.stop_consume()
            lch._destroy_queue()
            lch.close()
            # Use the parameter (not the enclosing gl_listen) like the other greenlets.
            gl_l.join(timeout=5)

        self.addCleanup(do_cleanups, gl_every_five, gl_every_three, gl_listen,
                        lch)

        ch = self.container.node.channel(RecvChannel)
        ch._recv_name = NameTrio(bootstrap.get_sys_name(), 'test_queue')
        ch._queue_auto_delete = False

        def cleanup_channel(thech):
            thech._destroy_queue()
            thech.close()

        self.addCleanup(cleanup_channel, ch)

        # declare exchange and queue, no binding yet
        ch._declare_exchange(ch._recv_name.exchange)
        ch._declare_queue(ch._recv_name.queue)
        ch._purge()

        # do binding to 5 pub only
        ch._bind('routed.5')

        # wait for one message
        self.five_events.get(timeout=10)

        # ensure 1 message, 0 consumer
        self.assertTupleEqual((1, 0), ch.get_stats())

        # start a consumer
        ch.start_consume()
        time.sleep(0.1)
        self.assertEqual(
            ch._recv_queue.qsize(),
            1)  # should have been delivered to the channel, waiting for us now

        # receive one message with instant timeout
        m, h, d = ch.recv(timeout=0)
        self.assertEqual(m, "5,0")
        ch.ack(d)

        # we have no more messages, should instantly fail
        self.assertRaises(PQueue.Empty, ch.recv, timeout=0)

        # stop consumer
        ch.stop_consume()

        # wait until next 5 publish event
        self.five_events.get(timeout=10)

        # start consumer again, empty queue
        ch.start_consume()
        time.sleep(0.1)
        while True:
            try:
                m, h, d = ch.recv(timeout=0)
                self.assertTrue(m.startswith('5,'))
                ch.ack(d)
            except PQueue.Empty:
                ch.stop_consume()
                break

        # wait for new message
        self.five_events.get(timeout=10)

        # consume and requeue
        ch.start_consume()
        time.sleep(0.1)
        m, h, d = ch.recv(timeout=0)
        self.assertTrue(m.startswith('5,'))
        ch.reject(d, requeue=True)

        # rabbit appears to deliver this later on, only when we've got another message in it
        # wait for another message publish
        num = self.five_events.get(timeout=10)
        self.assertEqual(num, 3)
        time.sleep(0.1)

        expect = ["5,2", "5,3"]
        while True:
            try:
                m, h, d = ch.recv(timeout=0)
                self.assertTrue(m.startswith('5,'))
                self.assertEqual(m, expect.pop(0))

                ch.ack(d)
            except PQueue.Empty:
                ch.stop_consume()
                self.assertListEqual(expect, [])
                break

        # let's change the binding to the 3 now, empty the testqueue first (artifact of test)
        while not self.three_events.empty():
            self.three_events.get(timeout=0)

        # we have to keep the exchange around - it will likely autodelete.
        ch2 = self.container.node.channel(RecvChannel)
        ch2.setup_listener(NameTrio(bootstrap.get_sys_name(), "another_queue"))

        ch._destroy_binding()
        ch._bind('routed.3')

        ch2._destroy_queue()
        ch2.close()

        self.three_events.get(timeout=10)
        ch.start_consume()
        time.sleep(0.1)
        self.assertEqual(ch._recv_queue.qsize(), 1)

        m, h, d = ch.recv(timeout=0)
        self.assertTrue(m.startswith('3,'))
        ch.ack(d)

        # wait for a new 3 to reject
        self.three_events.get(timeout=10)
        time.sleep(0.1)

        m, h, d = ch.recv(timeout=0)
        ch.reject(d, requeue=True)

        # recycle consumption, should get the requeued message right away?
        ch.stop_consume()
        ch.start_consume()
        time.sleep(0.1)

        self.assertEqual(ch._recv_queue.qsize(), 1)

        m2, h2, d2 = ch.recv(timeout=0)
        self.assertEqual(m, m2)

        ch.stop_consume()
예제 #33
0
class HttpHandler(object):
    """Bridge one WSGI request onto the channel layer and collect its reply.

    Constructing an instance sends an ``http.request`` message on
    ``default_channel_layer`` and then blocks until the response chunks
    delivered through :meth:`notify` include one whose ``more_content`` flag
    is false or absent. Body chunks are accumulated in ``self.response``; a
    chunk carrying ``status`` triggers ``start_response``.

    The handler registers itself in the module-level ``reply_channels`` dict
    under ``self.rc_name`` so replies can be routed to it; call
    :meth:`clean_up` once the response has been consumed.
    """

    def __init__(self, environ, start_response):
        # Make a unique name for our reply channel.
        self.reply_channel = default_channel_layer.new_channel(
            u"http.response." + default_channel_layer.client_prefix + "!")
        self.last_keepalive = time.time()

        self.data = []
        self.ev = Event()
        self.rc_name = self.reply_channel.split("!")[1]
        # Register before sending so that a reply arriving immediately can be
        # routed to this handler via notify().
        reply_channels[self.rc_name] = self

        try:
            body = environ['wsgi.input'].read()
            default_channel_layer.send(
                "http.request",
                {
                    "reply_channel": self.reply_channel,
                    # TODO: Correctly say if it's 1.1 or 1.0
                    "http_version": "1.1",
                    "method": environ['REQUEST_METHOD'],
                    "path": environ['PATH_INFO'],
                    "root_path": environ['SCRIPT_NAME'],
                    "scheme": "http",
                    # PEP 3333: QUERY_STRING "may be empty or absent".
                    "query_string": environ.get('QUERY_STRING', ''),
                    "headers": self._clean_headers(environ),
                    "body": body,
                    "client": environ['REMOTE_ADDR'],
                    "server": environ['SERVER_NAME'],
                })

            self.response = []
            self._wait_for_response(start_response)
        except BaseException:
            # If construction fails the caller never receives an instance to
            # call clean_up() on, so drop the registration here to avoid
            # leaking it in reply_channels.
            self.clean_up()
            raise

    @staticmethod
    def _clean_headers(environ):
        """Return ``[name, value]`` header pairs extracted from a WSGI environ.

        ``HTTP_*`` and ``CONTENT*`` keys become lower-case, dash-separated
        header names (``HTTP_USER_AGENT`` -> ``user-agent``,
        ``CONTENT_TYPE`` -> ``content-type``). Other keys are ignored.
        """
        headers = []
        for key, value in environ.items():
            if key.startswith("HTTP_"):
                name = key[5:]
            elif key.lower().startswith("content"):
                name = key
            else:
                continue
            headers.append([name.lower().replace("_", "-"), value])
        return headers

    def _wait_for_response(self, start_response):
        """Process reply chunks from ``self.data`` until the final one arrives."""
        while True:
            self.ev.wait()
            self.ev.clear()
            while self.data:
                chunk = self.data.pop(0)
                if 'status' in chunk:
                    start_response(to_status(chunk['status']),
                                   chunk.get('headers', []))

                if 'content' in chunk:
                    self.response.append(chunk['content'])

                # ``more_content`` is optional; a missing key means this is
                # the final chunk of the response.
                if not chunk.get('more_content', False):
                    return

    def notify(self, content):
        """Queue a reply chunk for this request and wake the waiting loop."""
        self.data.append(content)
        self.ev.set()

    def clean_up(self):
        """Unregister this handler from ``reply_channels`` (safe to repeat)."""
        reply_channels.pop(self.rc_name, None)
예제 #34
0
파일: download.py 프로젝트: dl9rdz/tawhiri
class DatasetDownloader(object):
    """Download one GFS forecast run into a Dataset and/or a gribmirror file.

    ``open()`` creates the output files inside a fresh temporary directory
    under ``directory``; ``download()`` fetches every expected file using
    ``DownloadWorker`` greenlets; ``close()`` (also called from ``__del__``)
    renames the outputs into ``directory`` if the download succeeded and
    deletes them otherwise, then removes the temporary directory.
    """

    # Queue entries; see the ordering note in __init__.
    _queue_item_type = namedtuple("queue_item",
                                  ("hour", "sleep_until", "filename",
                                   "expect_pressures", "bad_downloads"))

    # Remote file name of one forecast hour; ``pressure_flag`` is "" or "b"
    # (see _add_files).
    filename_pattern = \
            "gfs.t{ds_hour}z.pgrb2{pressure_flag}.0p50.f{axis_hour:03}"

    def __init__(self,
                 directory,
                 ds_time,
                 timeout=120,
                 first_file_timeout=600,
                 bad_download_retry_limit=3,
                 write_dataset=True,
                 write_gribmirror=True,
                 deadline=None,
                 dataset_host="ftp.ncep.noaa.gov",
                 dataset_path="/pub/data/nccf/com/gfs/prod/gfs.{0}/"):
        """Configure a download; no I/O happens until open()/download().

        Args:
            directory: final output directory (temporary files also go here).
            ds_time: run time; must be 00, 06, 12 or 18 h exactly on the hour.
            timeout, first_file_timeout, bad_download_retry_limit: stored for
                the download workers (presumably per-file limits; see
                DownloadWorker — not visible here).
            write_dataset, write_gribmirror: which outputs to produce; at
                least one must be true.
            deadline: absolute datetime after which download() gives up.
                Defaults to the later of now + 2h and ds_time + 9h30m.
            dataset_host: host name resolved and connected to by workers.
            dataset_path: remote directory template, formatted with
                ``YYYYMMDDHH`` of ds_time.

        Raises:
            ValueError: if neither output is requested.
        """
        # set these ASAP for close() via __del__ if __init__ raises something
        self.success = False
        self._dataset = None
        self._gribmirror = None
        self._tmp_directory = None

        assert ds_time.hour in (0, 6, 12, 18)
        assert ds_time.minute == ds_time.second == ds_time.microsecond == 0

        if not (write_dataset or write_gribmirror):
            raise ValueError("Choose write_dataset or write_gribmirror "
                             "(or both)")

        if deadline is None:
            deadline = max(datetime.now() + timedelta(hours=2),
                           ds_time + timedelta(hours=9, minutes=30))

        self.directory = directory
        self.ds_time = ds_time

        self.timeout = timeout
        self.first_file_timeout = first_file_timeout
        self.write_dataset = write_dataset
        self.write_gribmirror = write_gribmirror
        self.bad_download_retry_limit = bad_download_retry_limit

        self.deadline = deadline
        self.dataset_host = dataset_host
        self.dataset_path = dataset_path

        self.have_first_file = False

        self.files_complete = 0
        self.files_count = 0
        self.completed = Event()

        ds_time_str = self.ds_time.strftime("%Y%m%d%H")
        self.remote_directory = dataset_path.format(ds_time_str)

        self._greenlets = Group()
        self.unpack_lock = RLock()

        # Items in the queue are
        #   (hour, sleep_until, filename, ...)
        # so they sort by hour, and then if a not-found adds a delay to
        # a specific file, files from that hour without the delay
        # are tried first
        self._files = PriorityQueue()

        # areas in self.dataset.array are considered 'undefined' until
        #   self.checklist[index[:3]] is True, since unpack_grib may
        #   write to them, and then abort via ValueError before marking
        #   updating the checklist if the file turns out later to be bad

        # the checklist also serves as a sort of final sanity check:
        #   we also have "does this file contain all the records we think it
        #   should" checklists; see Worker._download_file

        self._checklist = make_checklist()

    def open(self):
        """Create the temporary directory and the requested output files."""
        logger.info("downloader: opening files for dataset %s", self.ds_time)

        self._tmp_directory = \
                tempfile.mkdtemp(dir=self.directory, prefix="download.")
        os.chmod(self._tmp_directory, 0o775)
        logger.debug("Temporary directory is %s", self._tmp_directory)

        if self.write_dataset:
            self._dataset = \
                Dataset(self.ds_time, directory=self._tmp_directory, new=True)

        if self.write_gribmirror:
            fn = Dataset.filename(self.ds_time,
                                  directory=self._tmp_directory,
                                  suffix=Dataset.SUFFIX_GRIBMIRROR)
            logger.debug("Opening gribmirror (truncate and write) %s %s",
                         self.ds_time, fn)
            self._gribmirror = open(fn, "w+")

    def download(self):
        """Fetch every expected file before ``self.deadline``.

        Sets ``self.success`` to True only if all files completed and every
        checklist entry is filled.

        Raises:
            ValueError: if the deadline has already passed, the download
                timed out, or records are missing at the end.
        """
        logger.info("download of %s starting", self.ds_time)

        _ttl, addresses = resolve_ipv4(self.dataset_host)
        logger.debug("Resolved to %s IPs", len(addresses))

        addresses = [inet_ntoa(x) for x in addresses]

        total_timeout = self.deadline - datetime.now()
        total_timeout_secs = total_timeout.total_seconds()
        if total_timeout_secs < 0:
            raise ValueError("Deadline already passed")
        else:
            logger.debug("Deadline in %s", total_timeout)

        self._add_files()
        self._run_workers(addresses, total_timeout_secs)

        if not self.completed.is_set():
            raise ValueError("timed out")

        if not self._checklist.all():
            raise ValueError("incomplete: records missing")

        self.success = True
        logger.debug("downloaded %s successfully", self.ds_time)

    def _add_files(self):
        """Queue one item per (forecast hour, pressure group) file."""
        ds_hr_str = self.ds_time.strftime("%H")
        pressure_groups = (("", Dataset.pressures_pgrb2f),
                           ("b", Dataset.pressures_pgrb2bf))

        for axis_hour in Dataset.axes.hour:
            for pressure_flag, expect_pr in pressure_groups:
                fn = self.filename_pattern.format(ds_hour=ds_hr_str,
                                                  pressure_flag=pressure_flag,
                                                  axis_hour=axis_hour)
                qi = self._queue_item_type(axis_hour, 0, fn, expect_pr, 0)
                self._files.put(qi)
                self.files_count += 1

        logger.info("Need to download %s files", self.files_count)

    def _run_workers(self, addresses, total_timeout_secs):
        """Run two workers per address until completion or timeout."""
        logger.debug("Spawning %s workers", len(addresses) * 2)

        # don't ask _join_all to raise the first exception it catches
        # if we're already raising something in the except block
        raising = False

        try:
            for worker_id, address in enumerate(addresses * 2):
                w = DownloadWorker(self, worker_id, address)
                w.start()
                w.link()
                self._greenlets.add(w)

            # worker unhandled exceptions are raised in this greenlet
            # via link(). They can appear in completed.wait and
            # greenlets.kill(block=True) only (the only times that this
            # greenlet will yield)
            self.completed.wait(timeout=total_timeout_secs)

        except BaseException:
            # includes LinkedCompleted - a worker should not exit cleanly
            # until we .kill them below
            logger.debug("_run_workers catch %s (will reraise)",
                         sys.exc_info()[1])
            raising = True
            raise

        finally:
            # don't leak workers.
            self._join_all(raise_exception=(not raising))

    def _join_all(self, raise_exception=False):
        """Kill all workers; optionally re-raise the first LinkedFailed seen."""
        # we need the loop to run to completion and so have it catch and
        # hold or discard exceptions for later.
        # track the first exception caught and re-raise that
        exc_info = None

        while len(self._greenlets):
            try:
                self._greenlets.kill(block=True)
            except greenlet.LinkedCompleted:
                # now that we've killed workers, these are expected.
                # ignore.
                pass
            except greenlet.LinkedFailed as e:
                if exc_info is None and raise_exception:
                    logger.debug("_join_all catch %s " "(will reraise)", e)
                    exc_info = sys.exc_info()
                else:
                    logger.debug(
                        "_join_all discarding %s "
                        "(already have exc)", e)

        if exc_info is not None:
            try:
                reraise(exc_info[1], None, exc_info[2])
            finally:
                # avoid circular reference
                del exc_info

    def _file_complete(self):
        """Record one finished file; set ``completed`` after the last one."""
        self.files_complete += 1
        self.have_first_file = True

        if self.files_complete == self.files_count:
            self.completed.set()

        # Multiply by a float first: under Python 2 integer division the
        # percentage would read 0 until every file had finished.
        logger.info("progress %s/%s %.1f%%", self.files_complete,
                    self.files_count,
                    100.0 * self.files_complete / self.files_count)

    def close(self, move_files=None):
        """Close outputs and move them into place or delete them.

        Args:
            move_files: move files into ``directory`` if true, delete them if
                false; defaults to ``self.success``.
        """
        if move_files is None:
            move_files = self.success

        if self._dataset is not None or self._gribmirror is not None or \
                self._tmp_directory is not None:
            if move_files:
                logger.info("moving downloaded files")
            else:
                logger.info("deleting failed download files")

        if self._dataset is not None:
            self._dataset.close()
            self._dataset = None
            if move_files:
                self._move_file()
            else:
                self._delete_file()

        if self._gribmirror is not None:
            self._gribmirror.close()
            self._gribmirror = None
            if move_files:
                self._move_file(Dataset.SUFFIX_GRIBMIRROR)
            else:
                self._delete_file(Dataset.SUFFIX_GRIBMIRROR)

        if self._tmp_directory is not None:
            self._remove_download_directory()
            self._tmp_directory = None

    def __del__(self):
        self.close()

    def _remove_download_directory(self):
        """Delete the temporary directory, warning about leftover files."""
        l = os.listdir(self._tmp_directory)
        if l:
            logger.warning("cleaning %s unknown file%s in temporary directory",
                           len(l), '' if len(l) == 1 else 's')

        logger.debug("removing temporary directory")
        shutil.rmtree(self._tmp_directory)

    def _move_file(self, suffix=''):
        """Rename one output file from the temporary to the final directory."""
        fn1 = Dataset.filename(self.ds_time,
                               directory=self._tmp_directory,
                               suffix=suffix)
        fn2 = Dataset.filename(self.ds_time,
                               directory=self.directory,
                               suffix=suffix)
        logger.debug("renaming %s to %s", fn1, fn2)
        os.rename(fn1, fn2)

    def _delete_file(self, suffix=''):
        """Delete one output file from the temporary directory."""
        fn = Dataset.filename(self.ds_time,
                              directory=self._tmp_directory,
                              suffix=suffix)
        logger.warning("deleting %s", fn)
        os.unlink(fn)
예제 #35
0
class UDPTransport:
    """Raiden message transport over a single UDP socket.

    Outgoing messages go onto per-``(recipient, queue_name)`` queues, each
    served by a ``single_queue_send`` greenlet; a ``healthcheck`` greenlet per
    recipient owns the healthy/unhealthy events that are also handed to those
    senders. Delivery is tracked with an ``AsyncResult`` per message id in
    ``messageids_to_asyncresults``, resolved when a ``Delivered`` arrives.
    """

    def __init__(self, discovery, udpsocket, throttle_policy, config):
        # these values are initialized by the start method
        self.queueids_to_queues: typing.Dict
        self.raiden: 'RaidenService'

        self.discovery = discovery
        self.config = config

        self.retry_interval = config['retry_interval']
        self.retries_before_backoff = config['retries_before_backoff']
        self.nat_keepalive_retries = config['nat_keepalive_retries']
        self.nat_keepalive_timeout = config['nat_keepalive_timeout']
        self.nat_invitation_timeout = config['nat_invitation_timeout']

        self.event_stop = Event()

        self.greenlets = list()
        self.addresses_events = dict()

        self.messageids_to_asyncresults = dict()

        # Maps the addresses to a dict with the latest nonce (using a dict
        # because python integers are immutable)
        self.nodeaddresses_to_nonces = dict()

        cache = cachetools.TTLCache(
            maxsize=50,
            ttl=CACHE_TTL,
        )
        cache_wrapper = cachetools.cached(cache=cache)
        self.get_host_port = cache_wrapper(discovery.get)

        self.throttle_policy = throttle_policy
        self.server = DatagramServer(udpsocket, handle=self._receive)

    def start(
        self,
        raiden: 'RaidenService',
        queueids_to_queues: typing.Dict,
    ):
        """Re-create the pending send queues and start the UDP server.

        ``queueids_to_queues`` maps ``(recipient, queue_name)`` to a sequence
        of ``SendMessageEvent``; each event is turned into a signed, encoded
        message before being queued.
        """
        self.raiden = raiden
        self.queueids_to_queues = dict()

        # server.stop() clears the handle. Since this may be a restart the
        # handle must always be set
        self.server.set_handle(self._receive)

        for (recipient, queue_name), queue in queueids_to_queues.items():
            encoded_queue = list()

            for sendevent in queue:
                message = message_from_sendevent(sendevent, raiden.address)
                raiden.sign(message)
                encoded = message.encode()

                encoded_queue.append((encoded, sendevent.message_identifier))

            self.init_queue_for(recipient, queue_name, encoded_queue)

        self.server.start()

    def stop_and_wait(self):
        """Stop all greenlets, close the socket and fail pending results."""
        # Stop handling incoming packets, but don't close the socket. The
        # socket can only be safely closed after all outgoing tasks are stopped
        self.server.stop_accepting()

        # Stop processing the outgoing queues
        self.event_stop.set()
        gevent.wait(self.greenlets)

        # All outgoing tasks are stopped. Now it's safe to close the socket. At
        # this point there might be some incoming message being processed,
        # keeping the socket open is not useful for these.
        self.server.stop()

        # Calling `.close()` on a gevent socket doesn't actually close the underlying os socket
        # so we do that ourselves here.
        # See: https://github.com/gevent/gevent/blob/master/src/gevent/_socket2.py#L208
        # and: https://groups.google.com/forum/#!msg/gevent/Ro8lRra3nH0/ZENgEXrr6M0J
        try:
            self.server._socket.close()  # pylint: disable=protected-access
        except socket.error:
            pass

        # Set all the pending results to False
        for async_result in self.messageids_to_asyncresults.values():
            async_result.set(False)

    def get_health_events(self, recipient):
        """ Starts a healthcheck task for `recipient` and returns a
        HealthEvents with locks to react on its current state.
        """
        if recipient not in self.addresses_events:
            self.start_health_check(recipient)

        return self.addresses_events[recipient]

    def start_health_check(self, recipient):
        """ Starts a task for healthchecking `recipient` if there is not
        one yet.
        """
        if recipient not in self.addresses_events:
            ping_nonce = self.nodeaddresses_to_nonces.setdefault(
                recipient,
                {'nonce': 0},  # HACK: Allows the task to mutate the object
            )

            events = healthcheck.HealthEvents(
                event_healthy=Event(),
                event_unhealthy=Event(),
            )

            self.addresses_events[recipient] = events

            self.greenlets.append(
                gevent.spawn(
                    healthcheck.healthcheck,
                    self,
                    recipient,
                    self.event_stop,
                    events.event_healthy,
                    events.event_unhealthy,
                    self.nat_keepalive_retries,
                    self.nat_keepalive_timeout,
                    self.nat_invitation_timeout,
                    ping_nonce,
                ))

    def init_queue_for(
        self,
        recipient: typing.Address,
        queue_name: bytes,
        items: typing.List[QueueItem_T],
    ) -> Queue_T:
        """ Create the queue identified by the pair `(recipient, queue_name)`
        and initialize it with `items`.
        """
        queueid = (recipient, queue_name)
        queue = self.queueids_to_queues.get(queueid)
        assert queue is None

        queue = NotifyingQueue(items=items)
        self.queueids_to_queues[queueid] = queue

        events = self.get_health_events(recipient)

        self.greenlets.append(
            gevent.spawn(
                single_queue_send,
                self,
                recipient,
                queue,
                self.event_stop,
                events.event_healthy,
                events.event_unhealthy,
                self.retries_before_backoff,
                self.retry_interval,
                self.retry_interval * 10,
            ))

        log.debug(
            'new queue created for',
            node=pex(self.raiden.address),
            token=pex(queue_name),
            to=pex(recipient),
        )

        return queue

    def get_queue_for(
        self,
        recipient: typing.Address,
        queue_name: bytes,
    ) -> Queue_T:
        """ Return the queue identified by the pair `(recipient, queue_name)`.

        If the queue doesn't exist it will be instantiated.
        """
        queueid = (recipient, queue_name)
        queue = self.queueids_to_queues.get(queueid)

        if queue is None:
            items = ()
            queue = self.init_queue_for(recipient, queue_name, items)

        return queue

    def send_async(
        self,
        recipient: typing.Address,
        queue_name: bytes,
        message: 'Message',
    ):
        """ Send a new ordered message to recipient.

        Messages that use the same `queue_name` are ordered.
        """

        if not is_binary_address(recipient):
            raise ValueError('Invalid address {}'.format(pex(recipient)))

        # These are not protocol messages, but transport specific messages
        if isinstance(message, (Delivered, Ping, Pong)):
            raise ValueError('Do not use send for {} messages'.format(
                message.__class__.__name__))

        messagedata = message.encode()
        if len(messagedata) > UDP_MAX_MESSAGE_SIZE:
            raise ValueError('message size exceeds the maximum {}'.format(
                UDP_MAX_MESSAGE_SIZE))

        # message identifiers must be unique
        message_id = message.message_identifier

        # ignore duplicates
        if message_id not in self.messageids_to_asyncresults:
            self.messageids_to_asyncresults[message_id] = AsyncResult()

            queue = self.get_queue_for(recipient, queue_name)
            queue.put((messagedata, message_id))

            log.debug(
                'MESSAGE QUEUED',
                node=pex(self.raiden.address),
                queue_name=queue_name,
                to=pex(recipient),
                message=message,
            )

    def maybe_send(self, recipient: typing.Address, message: Message):
        """ Send message to recipient if the transport is running. """

        if not is_binary_address(recipient):
            raise InvalidAddress('Invalid address {}'.format(pex(recipient)))

        messagedata = message.encode()
        host_port = self.get_host_port(recipient)

        self.maybe_sendraw(host_port, messagedata)

    def maybe_sendraw_with_result(
        self,
        recipient: typing.Address,
        messagedata: bytes,
        message_id: int,
    ) -> AsyncResult:
        """ Send message to recipient if the transport is running.

        Returns:
            An AsyncResult that will be set once the message is delivered. As
            long as the message has not been acknowledged with a Delivered
            message the function will return the same AsyncResult.
        """
        async_result = self.messageids_to_asyncresults.get(message_id)
        if async_result is None:
            async_result = AsyncResult()
            self.messageids_to_asyncresults[message_id] = async_result

        host_port = self.get_host_port(recipient)
        self.maybe_sendraw(host_port, messagedata)

        return async_result

    def maybe_sendraw(self, host_port: typing.Tuple[str, int],
                      messagedata: bytes):
        """ Send message to recipient if the transport is running. """

        # Don't sleep if timeout is zero, otherwise a context-switch is done
        # and the message is delayed, increasing it's latency
        sleep_timeout = self.throttle_policy.consume(1)
        if sleep_timeout:
            gevent.sleep(sleep_timeout)

        # Check the udp socket is still available before trying to send the
        # message. There must be *no context-switches after this test*.
        if hasattr(self.server, 'socket'):
            self.server.sendto(
                messagedata,
                host_port,
            )

    def _receive(self, data, host_port):  # pylint: disable=unused-argument
        try:
            self.receive(data)
        except RaidenShuttingDown:  # For a clean shutdown
            return

    def receive(self, messagedata: bytes):
        """ Handle an UDP packet. """
        # pylint: disable=unidiomatic-typecheck

        if len(messagedata) > UDP_MAX_MESSAGE_SIZE:
            log.error(
                'INVALID MESSAGE: Packet larger than maximum size',
                node=pex(self.raiden.address),
                message=hexlify(messagedata),
                length=len(messagedata),
            )
            return

        message = decode(messagedata)

        if type(message) == Pong:
            self.receive_pong(message)
        elif type(message) == Ping:
            self.receive_ping(message)
        elif type(message) == Delivered:
            self.receive_delivered(message)
        elif message is not None:
            self.receive_message(message)
        else:
            log.error(
                'INVALID MESSAGE: Unknown cmdid',
                node=pex(self.raiden.address),
                message=hexlify(messagedata),
            )

    def receive_message(self, message: Message):
        """ Handle a Raiden protocol message.

        The protocol requires durability of the messages. The UDP transport
        relies on the node's WAL for durability. The message will be converted
        to a state change, saved to the WAL, and *processed* before the
        durability is confirmed, which is a stronger property than what is
        required of any transport.
        """
        if on_udp_message(self.raiden, message):

            # Sending Delivered after the message is decoded and *processed*
            # gives a stronger guarantee than what is required from a
            # transport.
            #
            # Alternatives are, from weakest to strongest options:
            # - Just save it on disk and asynchronously process the messages
            # - Decode it, save to the WAL, and asynchronously process the
            #   state change
            # - Decode it, save to the WAL, and process it (the current
            #   implementation)
            delivered_message = Delivered(message.message_identifier)
            self.raiden.sign(delivered_message)

            self.maybe_send(
                message.sender,
                delivered_message,
            )

    def receive_delivered(self, delivered: Delivered):
        """ Handle a Delivered message.

        The Delivered message is how the UDP transport guarantees persistence
        by the partner node. The message itself is not part of the raiden
        protocol, but it's required by this transport to provide the required
        properties.
        """
        processed = ReceiveDelivered(delivered.delivered_message_identifier)
        self.raiden.handle_state_change(processed)

        message_id = delivered.delivered_message_identifier
        # Look up and remove the entry in this transport's own map in one
        # step: leaving it behind is a memory leak, and a Delivered for an
        # unknown id must not raise KeyError.
        async_result = self.messageids_to_asyncresults.pop(message_id, None)
        if async_result is not None:
            async_result.set()

    # Pings and Pongs are used to check the health status of another node. They
    # are /not/ part of the raiden protocol, only part of the UDP transport,
    # therefore these messages are not forwarded to the message handler.
    def receive_ping(self, ping: Ping):
        """ Handle a Ping message by answering with a Pong. """

        log.debug(
            'PING RECEIVED',
            node=pex(self.raiden.address),
            message_id=ping.nonce,
            message=ping,
            sender=pex(ping.sender),
        )

        pong = Pong(ping.nonce)
        self.raiden.sign(pong)

        try:
            self.maybe_send(ping.sender, pong)
        except (InvalidAddress, UnknownAddress) as e:
            log.debug("Couldn't send the `Pong` message", e=e)

    def receive_pong(self, pong: Pong):
        """ Handles a Pong message. """

        message_id = ('ping', pong.nonce, pong.sender)
        async_result = self.messageids_to_asyncresults.get(message_id)

        if async_result is not None:
            log.debug(
                'PONG RECEIVED',
                node=pex(self.raiden.address),
                message_id=pong.nonce,
            )

            async_result.set(True)

    def get_ping(self, nonce: int) -> bytes:
        """ Returns an encoded, signed Ping message.

        Note: Ping messages don't have an enforced ordering, so a Ping message
        with a higher nonce may be acknowledged first.
        """
        message = Ping(nonce)
        self.raiden.sign(message)
        message_data = message.encode()

        return message_data

    def set_node_network_state(self, node_address: typing.Address, node_state):
        """Feed a network-state change for ``node_address`` to the node."""
        state_change = ActionChangeNodeNetworkState(node_address, node_state)
        self.raiden.handle_state_change(state_change)
예제 #36
0
class ESLProtocol(object):
    """FreeSWITCH Event Socket Layer protocol handler (gevent based).

    ``receive_events`` reads header blocks from ``self.sock_file``, turns
    them into ``ESLEvent`` objects and routes them: command/api replies
    resolve the oldest pending ``AsyncResult`` in ``_commands_sent``; other
    events are queued for ``process_events``, which dispatches them to the
    handlers registered with ``register_handle``. ``self.sock`` and
    ``self.sock_file`` are expected to be set by a subclass or caller before
    ``start_event_handlers`` is called.
    """

    def __init__(self):
        self._run = True
        self._EOL = '\n'
        self._commands_sent = []
        self._auth_request_event = Event()
        self._receive_events_greenlet = None
        self._process_events_greenlet = None
        self.event_handlers = {}
        self._esl_event_queue = Queue()
        self._process_esl_event_queue = True
        self._lingering = False
        self.connected = False

    def start_event_handlers(self):
        """Spawn the receiving and the event-processing greenlets."""
        self._receive_events_greenlet = gevent.spawn(self.receive_events)
        self._process_events_greenlet = gevent.spawn(self.process_events)

    def register_handle(self, name, handler):
        """Add ``handler`` for event ``name``; duplicates are ignored."""
        if name not in self.event_handlers:
            self.event_handlers[name] = []
        if handler in self.event_handlers[name]:
            return
        self.event_handlers[name].append(handler)

    def unregister_handle(self, name, handler):
        """Remove ``handler`` for ``name``; drop the name when none are left.

        Raises:
            ValueError: if no handler is registered for ``name`` or
                ``handler`` is not among them.
        """
        if name not in self.event_handlers:
            raise ValueError('No handlers found for event: %s' % name)
        self.event_handlers[name].remove(handler)
        if not self.event_handlers[name]:
            del self.event_handlers[name]

    def receive_events(self):
        """Read and dispatch events until the connection ends or stop()."""
        buf = ''
        while self._run:
            try:
                data = self.sock_file.readline()
            except Exception:
                self._run = False
                self.connected = False
                self.sock.close()
                break

            if not data:
                if self.connected:
                    logging.error(
                        "Error receiving data, is FreeSWITCH running?")
                    self.connected = False
                    self._run = False
                break
            # An empty line terminates one event's header block.
            if data == self._EOL:
                event = ESLEvent(buf)
                buf = ''
                try:
                    self.handle_event(event)
                except EOFError:
                    # The stream ended in the middle of an event body.
                    logging.error(
                        "Connection closed while reading an event body.")
                    self.connected = False
                    self._run = False
                    break
                continue
            buf += data

    @staticmethod
    def _read_socket(sock, length):
        """Read exactly ``length`` bytes from the file-like ``sock``.

        Short reads are retried and logged, since they suggest the socket
        receive buffer is too small.

        Raises:
            EOFError: if the stream ends before ``length`` bytes were read
                (retrying an exhausted stream would otherwise spin forever).
        """
        data = sock.read(length)
        data_length = len(data)
        while data_length < length:
            logging.warning(
                'Socket should read %s bytes, but actually read %s bytes. '
                'Consider increasing "net.core.rmem_default".' %
                (length, data_length))
            chunk = sock.read(length - data_length)
            if not chunk:
                raise EOFError(
                    'Connection closed after %s of %s expected bytes.' %
                    (data_length, length))
            data += chunk
            data_length = len(data)
        return data

    def handle_event(self, event):
        """Route one parsed event according to its Content-Type header."""
        content_type = event.headers['Content-Type']
        if content_type == 'auth/request':
            self._auth_request_event.set()
        elif content_type == 'command/reply':
            async_response = self._commands_sent.pop(0)
            event.data = event.headers['Reply-Text']
            async_response.set(event)
        elif content_type == 'api/response':
            length = int(event.headers['Content-Length'])
            data = self._read_socket(self.sock_file, length)
            event.data = data
            async_response = self._commands_sent.pop(0)
            async_response.set(event)
        elif content_type == 'text/disconnect-notice':
            if event.headers.get('Content-Disposition') == 'linger':
                logging.debug('Linger activated')
                self._lingering = True
            else:
                self.connected = False
            # disconnect-notice is now a propagated event both for inbound
            # and outbound socket modes.
            # This is useful for outbound mode to notify all remaining
            # waiting commands to stop blocking and send a NotConnectedError
            self._esl_event_queue.put(event)
        elif content_type == 'text/rude-rejection':
            self.connected = False
            length = int(event.headers['Content-Length'])
            self._read_socket(self.sock_file, length)
            self._auth_request_event.set()
        else:
            length = int(event.headers['Content-Length'])
            data = self._read_socket(self.sock_file, length)
            if content_type == 'log/data':
                event.data = data
            else:
                event.parse_data(data)
            self._esl_event_queue.put(event)

    def _safe_exec_handler(self, handler, event):
        """Call ``handler(event)``, logging (not propagating) its errors."""
        try:
            handler(event)
        except Exception:
            # Catch Exception only: a bare except would also swallow
            # GreenletExit/KeyboardInterrupt and make the greenlet unkillable.
            logging.exception('ESL %s raised exception.' % handler.__name__)
            logging.error(pprint.pformat(event.headers))

    def process_events(self):
        """Dispatch queued events to their registered handlers."""
        logging.debug('Event Processor Running')
        while self._run:
            if not self._process_esl_event_queue:
                gevent.sleep(1)
                continue

            try:
                event = self._esl_event_queue.get(timeout=1)
            except gevent.queue.Empty:
                continue

            if event.headers.get('Event-Name') == 'CUSTOM':
                handlers = self.event_handlers.get(
                    event.headers.get('Event-Subclass'))
            else:
                handlers = self.event_handlers.get(
                    event.headers.get('Event-Name'))

            if event.headers.get('Content-Type') == 'text/disconnect-notice':
                handlers = self.event_handlers.get('DISCONNECT')

            if not handlers and event.headers.get(
                    'Content-Type') == 'log/data':
                handlers = self.event_handlers.get('log')

            if not handlers and '*' in self.event_handlers:
                handlers = self.event_handlers.get('*')

            if not handlers:
                continue

            if hasattr(self, 'before_handle'):
                self._safe_exec_handler(self.before_handle, event)

            for handle in handlers:
                self._safe_exec_handler(handle, event)

            if hasattr(self, 'after_handle'):
                self._safe_exec_handler(self.after_handle, event)

    def send(self, data):
        """Send one command and block until its reply event arrives.

        Raises:
            NotConnectedError: if the connection is not established.
        """
        if not self.connected:
            raise NotConnectedError()
        async_response = gevent.event.AsyncResult()
        self._commands_sent.append(async_response)
        raw_msg = (data + self._EOL * 2).encode('utf-8')
        try:
            self.sock.send(raw_msg)
        except Exception:
            # Replies are matched to commands by position, so a placeholder
            # for a command that failed to go out would steal the next reply.
            self._commands_sent.remove(async_response)
            raise
        response = async_response.get()
        return response

    def stop(self):
        """Say goodbye to the server, stop the greenlets and close sockets."""
        if self.connected:
            try:
                self.send('exit')
            except (NotConnectedError, socket.error):
                pass
        self._run = False
        # The greenlets exist only if start_event_handlers() was called.
        if self._receive_events_greenlet is not None:
            logging.info("Waiting for receive greenlet exit")
            self._receive_events_greenlet.join()
        if self._process_events_greenlet is not None:
            logging.info("Waiting for event processing greenlet exit")
            self._process_events_greenlet.join()
        self.sock.close()
        self.sock_file.close()
예제 #37
0
class ArticalSpider(object):
    """Gevent crawler: pull URLs from a bounded queue, parse each page and
    store the extracted article in the database.

    Settings are read from ``data.conf`` (see ``initConfig``):
    ``maxUrlQueueSize`` bounds the URL queue and ``poolSize`` bounds the
    number of concurrently active crawler greenlets.
    """

    # Request headers sent with every page fetch.
    _HEADERS = {
        'User-Agent':
        'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36',
    }

    # Absolute links inside <a ... href="http...">.
    _LINK_PATTERN = re.compile('<a[^>]+href="(http.*?)"')

    # Matches both ``<meta charset="x">`` and
    # ``<meta http-equiv=... content="...; charset=x">``; the capture stops
    # at the first character that cannot be part of an encoding name.
    _CHARSET_PATTERN = re.compile(
        r'<\s*meta[^>]*?charset\s*=\s*["\']?\s*([\w-]+)', re.I)

    # Queue slots left free for the start URLs when restoring a saved queue.
    _START_URL_RESERVE = 100

    def __init__(self):
        self.evt = Event()  # set by initQueue once the queue is filled
        self.initConfig()  # read data.conf
        self.initModules()  # parser, database, known URL md5s

        # Must be set before initQueue(), which flips it to True; resetting
        # it afterwards would leave the flag permanently False.
        self.isInitializeCompletely = False
        self.q = Queue(maxsize=self.maxsize)  # bounded URL queue
        self.initQueue()

        self.crawlUrlsCount = 0  # number of links pushed onto the queue
        self.crawlerID = 0  # id given to the most recently spawned worker
        self.pool = Pool(self.poolSize)  # greenlet pool

        self.startTime = None  # set by run(); also reset after each save

    def initModules(self):
        """Load the HTML parser, the database manager and the URL md5s that
        are already stored in MySQL."""
        logger.info('Initializing modules...')
        self.htmlParser = HtmlParser()  # article extraction
        self.sqlManager = SQLManager()  # database access
        logger.info('Reading url md5 from mysql...')
        # Used as a seen-set: crawlURL tests membership and assigns True.
        self.urlDict = self.sqlManager.getAllMd5()

    def initConfig(self):
        """Read crawler settings from the JSON file ``data.conf``."""
        logger.info('Initializing config...')
        with open('data.conf') as json_file:
            data = json.load(json_file)
        self.maxsize = data['maxUrlQueueSize']  # URL queue capacity
        self.poolSize = data['poolSize']  # max active greenlets
        self.fileName = data['urlQueueFileName']  # where the queue is saved
        self.startUrls = data['startUrls']  # seed URLs
        # Substrings: a found link is queued only if it contains one of
        # them; an empty list disables filtering.
        self.filterUrlsRegular = data['filterUrlsRegular']
        self.saveTime = data['saveTime']  # seconds between queue saves

    def initQueue(self):
        """Fill the queue with the URLs saved by the last run plus the start
        URLs, then mark initialization as complete.

        Start URLs that no longer fit in the bounded queue are dropped with a
        warning: nothing consumes the queue yet, so a blocking ``put`` here
        would never return.
        """
        self.loadLastUrlQueue()
        for url in self.startUrls[:self.maxsize]:
            if self.q.full():
                logger.warning(
                    'URL queue is full, dropping remaining start urls')
                break
            self.q.put(url)
        self.isInitializeCompletely = True
        self.evt.set()

    def loadLastUrlQueue(self):
        """Restore URLs saved by ``saveQueueUrls`` in a previous run.

        At most ``maxsize - 100`` URLs (never a negative count) are loaded,
        leaving room for the start URLs.

        :return: True if at least one saved URL was queued.
        """
        logger.info('Initializing queue...')
        if not os.path.exists(self.fileName):
            return False
        limit = max(self.maxsize - self._START_URL_RESERVE, 0)
        # NOTE: pickle.load can execute arbitrary code; only load a file
        # this crawler wrote itself.
        with open(self.fileName, 'rb') as f:
            urls = pickle.load(f)[:limit]
        for url in urls:
            self.q.put(url.strip())  # drop surrounding whitespace
        return bool(urls)

    def getCrawlUrlsCount(self):
        """Return how many links have been pushed onto the queue."""
        return self.crawlUrlsCount

    def getQueueSize(self):
        """Return the number of URLs currently in the queue."""
        return self.q.qsize()

    def saveQueueUrls(self):
        """Pickle a snapshot of the queued URLs to ``self.fileName``."""
        logger.info('Save queue urls')
        with open(self.fileName, 'wb') as f:
            urls = list(self.q.queue)  # copy, the queue itself is untouched
            pickle.dump(urls, f)

    def crawlURL(self, crawlerID):
        """Worker loop: take URLs from the queue, fetch and store each page,
        and queue the links found on it.

        The worker exits when ``get`` finds the queue empty (``Empty``) or
        ``put`` finds it full (``Full``); ``run`` spawns new workers.

        :param crawlerID: worker id (only used by the disabled log lines).
        """
        # Each worker keeps looping until the queue is empty or full to
        # reduce greenlet switches; the bounded queue still forces some.
        while True:
            if not self.isInitializeCompletely:  # wait for initQueue
                self.evt.wait()
            # Periodically persist the queue so a later run can resume.
            if time.time() - self.startTime > self.saveTime:
                self.saveQueueUrls()
                self.startTime = time.time()

            gevent.sleep(random.uniform(0, 1))  # throttle request rate
            try:
                url = self.q.get(timeout=0.1)  # yields while queue is empty
                md5_url = MD5(url)
                if md5_url in self.urlDict:
                    continue  # already crawled
                self.urlDict[md5_url] = True

                r = requests.get(url, timeout=5, headers=self._HEADERS)
                if r.status_code != 200:
                    logger.warning('Request error status: ' +
                                   str(r.status_code) + ': ' + url)
                    # A retry could be added here.
                    continue

                if r.encoding == 'ISO-8859-1':
                    # requests reports ISO-8859-1 when a text response
                    # declares no charset: prefer the page's <meta>
                    # declaration, otherwise guess from the raw bytes.
                    charset = self.detCharset(r.text)
                    if charset.lower() in ('utf-8', 'gb2312', 'gbk'):
                        r.encoding = charset
                    else:
                        r.encoding = chardet.detect(r.content)['encoding']

                self.insertMysql(r.text, url, md5_url)
                self._enqueueLinks(r.text)

            except Empty:  # q.get() timed out: queue is empty
                # logger.info('URL Queue is Empty! URLSpider-' + str(crawlerID) + ': stopping crawler...')
                break
            except Full:  # q.put() timed out: queue is full
                # logger.info('URL Queue is Full! URLSpider-' + str(crawlerID) + ': stopping crawler...')
                break
            except requests.exceptions.ConnectionError:
                # Too many connections: back off. gevent.sleep yields to the
                # other greenlets even when time is not monkey-patched.
                logger.warning('Connection refused')
                gevent.sleep(3)
            except requests.exceptions.ReadTimeout:
                logger.warning('Request readTimeout')
                # A retry could be added here.
            except requests.exceptions.RequestException as e:
                # Any other request failure (invalid URL, too many
                # redirects, ...): skip this URL instead of killing the worker.
                logger.warning('Request failed: ' + url + ': ' + str(e))

    def _enqueueLinks(self, html):
        """Queue every absolute link in *html* that passes the URL filter.

        :raises Full: from ``q.put`` when the queue stays full for 0.1 s.
        """
        for link in self._LINK_PATTERN.findall(html):
            link = link.strip()
            if self.filterUrlsRegular and not any(
                    filterUrl in link for filterUrl in self.filterUrlsRegular):
                continue
            self.q.put(link, timeout=0.1)  # yields while queue is full
            self.crawlUrlsCount += 1

    def insertMysql(self, html, url, md5):
        """Extract the article from *html* and insert it into the database.

        Pages whose extracted content is empty are not stored.
        """
        parseDict = self.htmlParser.extract_offline(html)
        content = parseDict['content']
        description = parseDict['description']
        keyword = parseDict['keyword']
        title = parseDict['title']
        if content != "":
            self.sqlManager.insert(
                Artical(content=content,
                        title=title,
                        keyword=keyword,
                        description=description,
                        url=url,
                        md5=md5))
            logger.info('Insert Mysql: ' + url)

    def detCharset(self, html):
        """Return the charset declared by a ``<meta>`` tag in *html*, or ""."""
        match = self._CHARSET_PATTERN.search(html)
        return match.group(1) if match else ""

    def run(self):
        """Run the worker pool; return once the queue is empty and no worker
        is active."""
        if self.q.qsize() == 0:
            logger.error('Please init Queue first (Check your .conf file)')
            return
        logger.info('Starting crawler...')
        self.startTime = time.time()
        while True:
            # Done when no worker is busy and no URL is left.
            if self.q.empty() and self.pool.free_count() == self.poolSize:
                break

            # One new worker per queued URL, capped by the free pool slots,
            # so the pool stays saturated.
            for _ in range(min(self.pool.free_count(), self.q.qsize())):
                self.crawlerID += 1
                self.pool.spawn(self.crawlURL, self.crawlerID)

            # Yield so the workers get to run (gevent only switches on I/O
            # or explicit sleeps).
            gevent.sleep(0.1)
        logger.warning('All crawler stopping...')
예제 #38
0
class RaidenService:
    """ A Raiden node. """
    def __init__(
        self,
        chain: BlockChainService,
        query_start_block: typing.BlockNumber,
        default_registry: TokenNetworkRegistry,
        default_secret_registry: SecretRegistry,
        private_key_bin,
        transport,
        config,
        discovery=None,
    ):
        """Build the node from its chain proxies, key, transport and config.

        ``config`` must provide at least ``shutdown_timeout`` and
        ``database_path``; ``start_async`` later also reads
        ``transport_type`` (and ``transport`` for UDP).

        Raises:
            ValueError: if ``private_key_bin`` is not exactly 32 bytes.
        """
        if not isinstance(private_key_bin,
                          bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        self.tokennetworkids_to_connectionmanagers = dict()
        self.identifier_to_results: typing.Dict[typing.PaymentIdentifier,
                                                AsyncResult, ] = dict()

        self.chain: BlockChainService = chain
        self.default_registry = default_registry
        self.query_start_block = query_start_block
        self.default_secret_registry = default_secret_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)
        self.discovery = discovery

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.transport = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.shutdown_timeout = config['shutdown_timeout']
        self.stop_event = Event()
        self.start_event = Event()
        self.chain.client.inject_stop_event(self.stop_event)

        self.wal = None
        self.snapshot_group = 0

        # This flag will be used to prevent the service from processing
        # state changes events until we know that pending transactions
        # have been dispatched.
        self.dispatch_events_lock = Semaphore(1)

        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Prevent concurrent access to the same db
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        # Serializes blockchain-event polling with filter installation
        # (see _callback_new_block and install_all_blockchain_filters).
        self.event_poll_lock = gevent.lock.Semaphore()

    def start_async(self) -> Event:
        """ Start the node asynchronously.

        Restores (or initializes) the state from the WAL, installs the
        blockchain filters, runs the alarm's first pass, dispatches pending
        transactions, then starts the alarm, the transport and the
        neighbour health checks.

        Returns:
            ``self.start_event``. It is set immediately for non-UDP
            transports; for UDP it is set once endpoint registration finishes.
        """
        self.start_event.clear()
        self.stop_event.clear()

        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # start the registration early to speed up the start
        if self.config['transport_type'] == 'udp':
            endpoint_registration_greenlet = gevent.spawn(
                self.discovery.register,
                self.address,
                self.config['transport']['udp']['external_ip'],
                self.config['transport']['udp']['external_port'],
            )

        # The database may be :memory:
        storage = sqlite.SQLiteStorage(self.database_path,
                                       serialize.PickleSerializer())
        self.wal = wal.restore_from_latest_snapshot(
            node.state_transition,
            storage,
        )

        if self.wal.state_manager.current_state is None:
            log.debug('No recoverable state available, created inital state')
            block_number = self.chain.block_number()

            state_change = ActionInitChain(
                random.Random(),
                block_number,
                self.chain.node_address,
                self.chain.network_id,
            )
            self.wal.log_and_dispatch(state_change, block_number)
            payment_network = PaymentNetworkState(
                self.default_registry.address,
                [],  # empty list of token network states as it's the node's startup
            )
            state_change = ContractReceiveNewPaymentNetwork(
                constants.NULL_HASH_BYTES,
                payment_network,
            )
            self.handle_state_change(state_change)

            # On first run Raiden needs to fetch all events for the payment
            # network, to reconstruct all token network graphs and find opened
            # channels
            last_log_block_number = 0
        else:
            # The `Block` state change is dispatched only after all the events
            # for that given block have been processed, filters can be safely
            # installed starting from this position without losing events.
            last_log_block_number = views.block_number(
                self.wal.state_manager.current_state)
            log.debug('Restored state from WAL',
                      last_restored_block=last_log_block_number)

        # Restore the current snapshot group
        self.snapshot_group = last_log_block_number // SNAPSHOT_BLOCK_COUNT

        # Install the filters using the correct from_block value, otherwise
        # blockchain logs can be lost.
        self.install_all_blockchain_filters(
            self.default_registry,
            self.default_secret_registry,
            last_log_block_number,
        )

        # Complete the first_run of the alarm task and synchronize with the
        # blockchain since the last run.
        #
        # Notes about setup order:
        # - The filters must be polled after the node state has been primed,
        # otherwise the state changes won't have effect.
        # - The alarm must complete its first run before the transport is started,
        #  to avoid rejecting messages for unknown channels.
        self.alarm.register_callback(self._callback_new_block)

        self.alarm.first_run()

        chain_state = views.state_from_raiden(self)
        # Dispatch pending transactions
        pending_transactions = views.get_pending_transactions(chain_state, )
        log.debug(
            'Processing pending transactions',
            num_pending_transactions=len(pending_transactions),
        )
        # While this lock is held, handle_state_change logs state changes
        # but does not dispatch their events.
        with self.dispatch_events_lock:
            for transaction in pending_transactions:
                on_raiden_event(self, transaction)

        self.alarm.start()

        queueids_to_queues = views.get_all_messagequeues(chain_state)
        self.transport.start(self, queueids_to_queues)

        # Health check needs the transport layer
        self.start_neighbours_healthcheck()

        if self.config['transport_type'] == 'udp':

            def set_start_on_registration(_):
                self.start_event.set()

            endpoint_registration_greenlet.link(set_start_on_registration)
        else:
            self.start_event.set()

        return self.start_event

    def start(self) -> Event:
        """ Start the node and block until start-up has completed.

        Returns:
            The (already set) start event, matching the annotation.
        """
        start_event = self.start_async()
        start_event.wait()
        return start_event

    def start_neighbours_healthcheck(self):
        """Start a transport health check for every neighbour node except
        the connection manager's bootstrap address."""
        for neighbour in views.all_neighbour_nodes(
                self.wal.state_manager.current_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def stop(self):
        """ Stop the node.

        Stops the transport and the alarm, waits (bounded by
        ``shutdown_timeout``) for their greenlets, uninstalls the blockchain
        filters and releases the database lock if one is used.
        """
        # Needs to come before any greenlets joining
        self.stop_event.set()
        self.transport.stop_and_wait()
        self.alarm.stop_async()

        wait_for = [self.alarm]
        wait_for.extend(getattr(self.transport, 'greenlets', []))
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        gevent.wait(wait_for, timeout=self.shutdown_timeout)

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        try:
            with gevent.Timeout(self.shutdown_timeout):
                self.blockchain_events.uninstall_all_event_listeners()
        except (gevent.timeout.Timeout, RaidenShuttingDown):
            pass

        self.blockchain_events.reset()

        if self.db_lock is not None:
            self.db_lock.release()

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def get_block_number(self):
        """Return the block number recorded in the current WAL state."""
        return views.block_number(self.wal.state_manager.current_state)

    def handle_state_change(self, state_change, block_number=None):
        """Log *state_change* to the WAL and dispatch the resulting events.

        Args:
            state_change: the state change to apply.
            block_number: block to record it at; defaults to the current
                state's block number.

        Returns:
            The events produced, or ``[]`` when ``dispatch_events_lock`` is
            held (the state change is still logged, but its events are not
            dispatched).
        """
        log.debug('STATE CHANGE',
                  node=pex(self.address),
                  state_change=state_change)

        if block_number is None:
            block_number = self.get_block_number()

        # Take a snapshot every SNAPSHOT_BLOCK_COUNT
        # TODO: Gather more data about storage requirements
        # and update the value to specify how often we need
        # capturing a snapshot should take place
        new_snapshot_group = block_number // SNAPSHOT_BLOCK_COUNT
        if new_snapshot_group > self.snapshot_group:
            log.debug(f'Storing snapshot at block: {block_number}')
            self.wal.snapshot()
            self.snapshot_group = new_snapshot_group

        event_list = self.wal.log_and_dispatch(state_change, block_number)

        if self.dispatch_events_lock.locked():
            return []

        for event in event_list:
            log.debug('RAIDEN EVENT',
                      node=pex(self.address),
                      raiden_event=event)

            on_raiden_event(self, event)

        return event_list

    def set_node_network_state(self, node_address, network_state):
        """Record a network-state change for *node_address* in the WAL."""
        state_change = ActionChangeNodeNetworkState(node_address,
                                                    network_state)
        self.wal.log_and_dispatch(state_change, self.get_block_number())

    def start_health_check_for(self, node_address):
        """Ask the transport to start health-checking *node_address*."""
        self.transport.start_health_check(node_address)

    def _callback_new_block(self, current_block_number, chain_id):
        """Called once a new block is detected by the alarm task.

        Note:
            This should be called only once per block, otherwise there will be
            duplicated `Block` state changes in the log.

            Therefore this method should be called only once a new block is
            mined with the appropriate block_number argument from the
            AlarmTask.
        """
        # Raiden relies on blockchain events to update its off-chain state,
        # therefore some APIs /used/ to forcefully poll for events.
        #
        # This was done for APIs which have on-chain side-effects, e.g.
        # opening a channel, where polling the event is required to update
        # off-chain state to providing a consistent view to the caller, e.g.
        # the channel exists after the API call returns.
        #
        # That pattern introduced a race, because the events are returned only
        # once per filter, and this method would be called concurrently by the
        # API and the AlarmTask. The following lock is necessary, to ensure the
        # expected side-effects are properly applied (introduced by the commit
        # 3686b3275ff7c0b669a6d5e2b34109c3bdf1921d)
        with self.event_poll_lock:
            for event in self.blockchain_events.poll_blockchain_events(
                    current_block_number):
                # These state changes will be processed with a block_number
                # which is /larger/ than the ChainState's block_number.
                on_blockchain_event(self, event, current_block_number,
                                    chain_id)

            # On restart the Raiden node will re-create the filters with the
            # ethereum node. These filters will have the from_block set to the
            # value of the latest Block state change. To avoid missing events
            # the Block state change is dispatched only after all of the events
            # have been processed.
            #
            # This means on some corner cases a few events may be applied
            # twice, this will happen if the node crashed and some events have
            # been processed but the Block state change has not been
            # dispatched.
            state_change = Block(current_block_number)
            self.handle_state_change(state_change, current_block_number)

    def sign(self, message):
        """ Sign message inplace.

        Raises:
            ValueError: if *message* is not a ``SignedMessage``.
        """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.private_key)

    def install_all_blockchain_filters(
        self,
        token_network_registry_proxy: TokenNetworkRegistry,
        secret_registry_proxy: SecretRegistry,
        from_block: typing.BlockNumber,
    ):
        """Install event listeners, starting at *from_block*, for the token
        network registry, the secret registry and every token network the
        current state knows about for that registry."""
        with self.event_poll_lock:
            node_state = views.state_from_raiden(self)
            token_networks = views.get_token_network_identifiers(
                node_state,
                token_network_registry_proxy.address,
            )

            self.blockchain_events.add_token_network_registry_listener(
                token_network_registry_proxy,
                from_block,
            )
            self.blockchain_events.add_secret_registry_listener(
                secret_registry_proxy,
                from_block,
            )

            for token_network in token_networks:
                token_network_proxy = self.chain.token_network(token_network)
                self.blockchain_events.add_token_network_listener(
                    token_network_proxy,
                    from_block,
                )

    def connection_manager_for_token_network(self, token_network_identifier):
        """Return the (lazily created, cached) ConnectionManager for a token
        network registered in the default registry.

        Raises:
            InvalidAddress: if the identifier is not a binary address or the
                token network is not registered.
        """
        if not is_binary_address(token_network_identifier):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_identifiers(
            views.state_from_raiden(self),
            self.default_registry.address,
        )

        if token_network_identifier not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokennetworkids_to_connectionmanagers.get(
            token_network_identifier)

        if manager is None:
            manager = ConnectionManager(self, token_network_identifier)
            self.tokennetworkids_to_connectionmanagers[
                token_network_identifier] = manager

        return manager

    def leave_all_token_networks(self):
        """Log and dispatch an ``ActionLeaveAllNetworks`` state change."""
        state_change = ActionLeaveAllNetworks()
        self.wal.log_and_dispatch(state_change, self.get_block_number())

    def close_and_settle(self):
        """Leave all token networks and, if any connection manager exists,
        wait until all channels are settled."""
        log.info('raiden will close and settle all channels now')

        self.leave_all_token_networks()

        if self.tokennetworkids_to_connectionmanagers:
            waiting.wait_for_settle_all_channels(
                self,
                self.alarm.sleep_time,
            )

    def mediated_transfer_async(
        self,
        token_network_identifier,
        amount,
        target,
        identifier,
    ):
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer, the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.

        Returns the ``AsyncResult`` from ``start_mediated_transfer``.
        """

        async_result = self.start_mediated_transfer(
            token_network_identifier,
            amount,
            target,
            identifier,
        )

        return async_result

    def direct_transfer_async(self, token_network_identifier, amount, target,
                              identifier):
        """ Do a direct transfer with target.

        Direct transfers are non cancellable and non expirable, since these
        transfers are a signed balance proof with the transferred amount
        incremented.

        Because the transfer is non cancellable, there is a level of trust with
        the target. After the message is sent the target is effectively paid
        and then it is not possible to revert.

        The async result will be set to False iff there is no direct channel
        with the target or the payer does not have balance to complete the
        transfer, otherwise because the transfer is non expirable the async
        result *will never be set to False* and if the message is sent it will
        hang until the target node acknowledge the message.

        This transfer should be used as an optimization, since only two packets
        are required to complete the transfer (from the payers perspective),
        whereas the mediated transfer requires 6 messages.

        NOTE(review): despite the paragraphs above, this method currently
        returns None; no async result is created or returned here.
        """

        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        direct_transfer = ActionTransferDirect(
            token_network_identifier,
            target,
            identifier,
            amount,
        )

        self.handle_state_change(direct_transfer)

    def start_mediated_transfer(
        self,
        token_network_identifier,
        amount,
        target,
        identifier,
    ):
        """Start (or reuse) a mediated transfer and return its AsyncResult.

        A ``None`` identifier is replaced by ``create_default_identifier()``.
        If a transfer with the same identifier was already started, its
        existing AsyncResult is returned and nothing new is dispatched.
        """

        self.start_health_check_for(target)

        if identifier is None:
            identifier = create_default_identifier()

        if identifier in self.identifier_to_results:
            return self.identifier_to_results[identifier]

        async_result = AsyncResult()
        self.identifier_to_results[identifier] = async_result

        secret = random_secret()
        init_initiator_statechange = initiator_init(
            self,
            identifier,
            amount,
            secret,
            token_network_identifier,
            target,
        )

        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return async_result

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        """Dispatch the mediator init state change for *transfer*."""
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        """Health-check the initiator and dispatch the target init state
        change for *transfer*."""
        self.start_health_check_for(transfer.initiator)
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)
예제 #39
0
class DaemonWatchdog(Greenlet):
    """
    DaemonWatchdog::

    Watch Ceph daemons for failures. If an extended failure is detected (i.e.
    not intentional), then the watchdog will unmount file systems and send
    SIGTERM to all daemons. The duration of an extended failure is configurable
    with watchdog_daemon_timeout.

    watchdog_daemon_timeout [default: 300]: number of seconds a daemon
        is allowed to be failed before the watchdog will bark.
    """
    def __init__(self, ctx, manager, config, thrashers):
        Greenlet.__init__(self)
        self.ctx = ctx
        self.config = config
        self.e = None  # exception raised by watch(), if any
        self.logger = log.getChild('daemon_watchdog')
        self.manager = manager
        self.name = 'watchdog'
        self.stopping = Event()  # set by stop() to end the watch loop
        self.thrashers = thrashers

    def _run(self):
        """Greenlet entry point: run watch() and record any exception."""
        try:
            self.watch()
        except Exception as e:
            # See _run exception comment for MDSThrasher
            self.e = e
            self.logger.exception("exception:")
            # allow successful completion so gevent doesn't see an exception...

    def log(self, x):
        """Write data to logger"""
        self.logger.info(x)

    def stop(self):
        """Ask the watch loop to exit after its current iteration."""
        self.stopping.set()

    def bark(self):
        """Force-unmount every mount, then SIGTERM all live mds and mon
        daemons. Errors from individual mounts/daemons are logged and
        skipped."""
        self.log("BARK! unmounting mounts and killing all daemons")
        for mount in self.ctx.mounts.values():
            try:
                mount.umount_wait(force=True)
            except Exception:
                self.logger.exception("ignoring exception:")
        daemons = []
        for role in ('mds', 'mon'):
            daemons.extend(
                daemon for daemon in self.ctx.daemons.iter_daemons_of_role(
                    role, cluster=self.manager.cluster)
                if daemon.running() and not daemon.proc.finished)
        for daemon in daemons:
            try:
                daemon.signal(signal.SIGTERM)
            except Exception:
                self.logger.exception("ignoring exception:")

    def watch(self):
        """Poll mon/mds daemons every 5 s until stopped.

        A daemon counts as failed when ``running()`` is true but its process
        has finished. The watchdog barks (and returns) once any daemon has
        been failed for longer than ``watchdog_daemon_timeout`` seconds, or
        when any thrasher has recorded an exception.
        """
        self.log("watchdog starting")
        daemon_timeout = int(self.config.get('watchdog_daemon_timeout', 300))
        # name -> (daemon, time the failure was first seen)
        daemon_failure_time = {}
        while not self.stopping.is_set():
            bark = False
            now = time.time()

            daemon_failures = []
            for role in ('mon', 'mds'):
                daemon_failures.extend(
                    daemon for daemon in self.ctx.daemons.iter_daemons_of_role(
                        role, cluster=self.manager.cluster)
                    if daemon.running() and daemon.proc.finished)

            failed_names = set()
            for daemon in daemon_failures:
                name = daemon.role + '.' + daemon.id_
                failed_names.add(name)
                dt = daemon_failure_time.setdefault(name, (daemon, now))
                assert dt[0] is daemon
                delta = now - dt[1]
                self.log("daemon {name} is failed for ~{t:.0f}s".format(
                    name=name, t=delta))
                if delta > daemon_timeout:
                    bark = True

            # If a daemon is no longer failed, remove it from tracking.
            # Iterate over a snapshot: deleting from a dict while iterating
            # its live keys view raises RuntimeError.
            for name in list(daemon_failure_time):
                if name not in failed_names:
                    self.log(
                        "daemon {name} has been restored".format(name=name))
                    del daemon_failure_time[name]

            for thrasher in self.thrashers:
                if thrasher.e is not None:
                    self.log("thrasher on fs.{name} failed".format(
                        name=thrasher.fs.name))
                    bark = True

            if bark:
                self.bark()
                return

            sleep(5)

        self.log("watchdog finished")
예제 #40
0
class TriggeredService(Service):
    """A service receiving notifications to perform an operation.

    This is a base class implementing a common pattern in CMS: a
    service performing operations when certain conditions are met.

    Often, the operation is "do something on object x" and the
    condition is a condition on the fields of object x, but this
    pattern is not enforced by this class: operations need not be in
    that form.

    The pattern is implemented through the following blocks.
    - The method enqueue, which schedules a new operation. This can be
      used by subclasses when they receive a notification that an
      operation is needed, or in any other context.
    - A sweeper greenlet that asks subclasses to search and enqueue
      operations that were missed by the previous step.
    - A list of executors (each running in its own greenlet), each of
      which takes care of performing all operations.

    Note that if there are multiple executors, each operation will be
    executed by all of them. Indeed, having multiple executors is
    required when we need to execute slightly different versions of
    the same operation depending on local variables in the
    executors. For example, sending data to different machines.

    If required, subclasses can override enqueue() to change this
    behavior, for example to dispatch different operations to
    different executors.

    """
    def __init__(self, shard):
        """Initialize the sweeper loop.

        shard (int): which service shard to run.

        """
        Service.__init__(self, shard)

        self._executors = []

        # Monotonic timestamp of the start of the current/last sweep.
        self._sweeper_start = None
        # Set by search_operations_not_done() to wake the sweeper early.
        self._sweeper_event = Event()
        # Guards against spawning the sweeper loop more than once.
        self._sweeper_started = False
        # Seconds between the starts of two consecutive sweeps.
        self._sweeper_timeout = None

    def add_executor(self, executor):
        """Add an executor for the service and spawn its run loop.

        executor (Executor): the executor to register; its run()
            method is started in a new greenlet.

        """
        # Set up and spawn the executors.
        #
        # TODO: link to greenlet and react to deaths.
        self._executors.append(executor)
        gevent.spawn(executor.run)

    def get_executor(self):
        """Return the first executor (without checking it is unique).

        return (Executor): the first executor.

        raise (IndexError): if no executor has been added.

        """
        return self._executors[0]

    def enqueue(self, operation, priority=None, timestamp=None):
        """Add an operation to the queue of each executor.

        operation (QueueItem): the operation to enqueue.
        priority (int|None): the priority, or None to use default.
        timestamp (datetime|None): the timestamp of the first request
            for the operation, or None to use now.

        return (int): the number of executors that successfully added
            the operation to their queue.

        """
        ret = 0
        for executor in self._executors:
            if executor.enqueue(operation, priority, timestamp):
                ret += 1
        return ret

    def dequeue(self, operation):
        """Remove an operation from the queue of each executor.

        operation (QueueItem): the operation to dequeue.

        """
        for executor in self._executors:
            executor.dequeue(operation)

    def start_sweeper(self, timeout):
        """Start sweeper loop with given timeout.

        Calling this more than once only logs a warning; the loop is
        spawned at most once.

        timeout (float): timeout in seconds.

        """
        if not self._sweeper_started:
            self._sweeper_started = True
            self._sweeper_timeout = timeout

            # TODO: link to greenlet and react to its death.
            gevent.spawn(self._sweeper_loop)
        else:
            logger.warning("Service tried to start the sweeper loop twice.")

    def _sweeper_loop(self):
        """Regularly check for missed operations.

        Run the sweep once every _sweeper_timeout seconds but make
        sure that no two sweeps run simultaneously. That is, start a
        new sweep _sweeper_timeout seconds after the previous one
        started or when the previous one finished, whichever comes
        last.

        The search_operations_not_done RPC method can interfere with
        this regularity, as it tries to run a sweeper as soon as
        possible: immediately, if no sweeper is running, or as soon as
        the current one terminates.

        Any error during the sweep is sent to the logger and then
        suppressed, because the loop must go on.

        """
        while True:
            self._sweeper_start = time.monotonic()
            # Cleared before sweeping, so a request arriving during the
            # sweep makes the wait below return immediately.
            self._sweeper_event.clear()

            try:
                self._sweep()
            except Exception:
                logger.error(
                    "Unexpected error when searching for missed "
                    "operations.",
                    exc_info=True)

            self._sweeper_event.wait(
                max(
                    self._sweeper_start + self._sweeper_timeout -
                    time.monotonic(), 0))

    def _sweep(self):
        """Check for missed operations and log how many were found.

        The elapsed time is measured with time.monotonic(), the same
        clock _sweeper_loop uses, so that wall-clock steps (NTP
        corrections, manual changes) cannot produce negative or bogus
        durations in the log.

        """
        logger.info("Start looking for missing operations.")
        start_time = time.monotonic()
        counter = self._missing_operations()
        logger.info("Found %d missed operation(s) in %d ms.", counter,
                    (time.monotonic() - start_time) * 1000)

    def _missing_operations(self):
        """Enqueue missed operations, and return their number.

        The service is supposed to enqueue all operations that need to
        be done, and return the number of operations enqueued. This
        base implementation enqueues nothing.

        return (int): the number of operations enqueued.

        """
        return 0

    @rpc_method
    def search_operations_not_done(self):
        """Make the sweeper loop fire the sweeper as soon as possible."""
        self._sweeper_event.set()

    @rpc_method
    def queue_status(self):
        """Return the status of the queues.

        More precisely, a list indexed by each executor, whose
        elements are the list of entries in the executor's queue. The
        first item is the top item, the others are not in order.

        return ([[QueueEntry]]): the list with the queued elements.

        """
        return [executor.get_status() for executor in self._executors]
예제 #41
0
class RaidenProtocol:
    """ Encode the message into a packet and send it.

    Each message received is stored by hash and if it is received twice the
    previous answer is resent.

    Repeat sending messages until an acknowledgment is received or the maximum
    number of retries is hit.
    """
    def __init__(self, transport, discovery, raiden, retry_interval,
                 retries_before_backoff, nat_keepalive_retries,
                 nat_keepalive_timeout, nat_invitation_timeout):

        self.transport = transport
        self.discovery = discovery
        self.raiden = raiden

        self.retry_interval = retry_interval
        self.retries_before_backoff = retries_before_backoff

        self.nat_keepalive_retries = nat_keepalive_retries
        self.nat_keepalive_timeout = nat_keepalive_timeout
        self.nat_invitation_timeout = nat_invitation_timeout

        # Set by stop_and_wait(); passed to every queue and healthcheck
        # greenlet spawned by this object.
        self.event_stop = Event()

        # Maps (receiver_address, token_address) to its NotifyingQueue.
        self.channel_queue = dict()  # TODO: Change keys to the channel address
        # Every greenlet spawned by this object; joined in stop_and_wait().
        self.greenlets = list()
        # Maps a receiver address to its HealthEvents.
        self.addresses_events = dict()

        # Maps the echohash of received and *successfully* processed messages
        # to a (receiver_address, encoded `Processed` message) tuple, used to
        # ignore duplicate messages and resend the `Processed` message.
        self.receivedhashes_to_processedmessages = dict()

        # Maps the echohash to a SentMessageState
        self.senthashes_to_states = dict()

        # Maps the addresses to a dict with the latest nonce (using a dict
        # because python integers are immutable)
        self.nodeaddresses_to_nonces = dict()

        # Memoize discovery lookups (at most 50 entries, CACHE_TTL lifetime).
        cache = cachetools.TTLCache(
            maxsize=50,
            ttl=CACHE_TTL,
        )
        cache_wrapper = cachetools.cached(cache=cache)
        self.get_host_port = cache_wrapper(discovery.get)

    def start(self):
        """ Start the underlying transport. """
        self.transport.start()

    def stop_and_wait(self):
        """ Stop all outgoing tasks, close the transport and release every
        waiter whose message is still unacknowledged with a False result.
        """
        # Stop handling incoming packets, but don't close the socket. The
        # socket can only be safely closed after all outgoing tasks are stopped
        self.transport.stop_accepting()

        # Stop processing the outgoing queues
        self.event_stop.set()
        gevent.wait(self.greenlets)

        # All outgoing tasks are stopped. Now it's safe to close the socket. At
        # this point there might be some incoming message being processed,
        # keeping the socket open is not useful for these.
        self.transport.stop()

        # Set all the pending results to False. Results that are already
        # set (e.g. to True by receive_processed) must not be touched:
        # AsyncResult.set overwrites the stored value, which would turn an
        # acknowledged message into a failed one.
        for wait_processed in self.senthashes_to_states.values():
            if not wait_processed.async_result.ready():
                wait_processed.async_result.set(False)

    def get_health_events(self, receiver_address):
        """ Starts a healthcheck task for `receiver_address` and returns a
        HealthEvents with locks to react on its current state.
        """
        if receiver_address not in self.addresses_events:
            self.start_health_check(receiver_address)

        return self.addresses_events[receiver_address]

    def start_health_check(self, receiver_address):
        """ Starts a task for healthchecking `receiver_address` if there is not
        one yet.
        """
        if receiver_address not in self.addresses_events:
            ping_nonce = self.nodeaddresses_to_nonces.setdefault(
                receiver_address,
                {'nonce': 0},  # HACK: Allows the task to mutate the object
            )

            events = HealthEvents(
                event_healthy=Event(),
                event_unhealthy=Event(),
            )

            self.addresses_events[receiver_address] = events

            self.greenlets.append(
                gevent.spawn(
                    healthcheck,
                    self,
                    receiver_address,
                    self.event_stop,
                    events.event_healthy,
                    events.event_unhealthy,
                    self.nat_keepalive_retries,
                    self.nat_keepalive_timeout,
                    self.nat_invitation_timeout,
                    ping_nonce,
                ))

    def get_channel_queue(self, receiver_address, token_address):
        """ Return the outgoing queue for (receiver_address, token_address),
        creating it and spawning its sender greenlet on first use.
        """
        key = (
            receiver_address,
            token_address,
        )

        if key in self.channel_queue:
            return self.channel_queue[key]

        queue = NotifyingQueue()
        self.channel_queue[key] = queue

        events = self.get_health_events(receiver_address)

        self.greenlets.append(
            gevent.spawn(
                single_queue_send,
                self,
                receiver_address,
                queue,
                self.event_stop,
                events.event_healthy,
                events.event_unhealthy,
                self.retries_before_backoff,
                self.retry_interval,
                self.retry_interval * 10,
            ))

        if log.isEnabledFor(logging.DEBUG):
            log.debug(
                'new queue created for',
                node=pex(self.raiden.address),
                token=pex(token_address),
                to=pex(receiver_address),
            )

        return queue

    def send_async(self, receiver_address, message):
        """ Queue `message` for `receiver_address` and return an AsyncResult
        that is set once a `Processed` reply arrives.

        Sending a message that is already pending returns the existing
        AsyncResult without queueing it again.

        Raises:
            ValueError: invalid address, `Processed`/`Ping` message, or an
                encoded message larger than UDP_MAX_MESSAGE_SIZE.
        """
        if not isaddress(receiver_address):
            raise ValueError('Invalid address {}'.format(
                pex(receiver_address)))

        if isinstance(message, (Processed, Ping)):
            raise ValueError(
                'Do not use send for `Processed` or `Ping` messages')

        # Messages that are not unique per receiver can result in hash
        # collision, e.g. Secret messages. The hash collision has the undesired
        # effect of aborting message resubmission once /one/ of the nodes
        # replied with an Ack, adding the receiver address into the echohash to
        # avoid these collisions.
        messagedata = message.encode()
        echohash = sha3(messagedata + receiver_address)

        if len(messagedata) > UDP_MAX_MESSAGE_SIZE:
            raise ValueError('message size exceeds the maximum {}'.format(
                UDP_MAX_MESSAGE_SIZE))

        # All messages must be ordered, but only on a per channel basis.
        token_address = getattr(message, 'token', b'')

        # Ignore duplicated messages
        if echohash not in self.senthashes_to_states:
            async_result = AsyncResult()
            self.senthashes_to_states[echohash] = SentMessageState(
                async_result,
                receiver_address,
            )

            queue = self.get_channel_queue(
                receiver_address,
                token_address,
            )

            if log.isEnabledFor(logging.DEBUG):
                log.debug(
                    'SENDING MESSAGE',
                    to=pex(receiver_address),
                    node=pex(self.raiden.address),
                    message=message,
                    echohash=pex(echohash),
                )

            queue.put(messagedata)
        else:
            wait_processed = self.senthashes_to_states[echohash]
            async_result = wait_processed.async_result

        return async_result

    def send_and_wait(self, receiver_address, message, timeout=None):
        """Sends a message and wait for the response 'Processed' message."""
        async_result = self.send_async(receiver_address, message)
        return async_result.wait(timeout=timeout)

    def maybe_send_processed(self, receiver_address, processed_message):
        """ Send processed_message to receiver_address if the transport is running.

        The encoded message is remembered by its echo hash so that it can be
        resent when the original message is received again.
        """
        if not isaddress(receiver_address):
            raise InvalidAddress('Invalid address {}'.format(
                pex(receiver_address)))

        if not isinstance(processed_message, Processed):
            raise ValueError(
                'Use _maybe_send_processed only for `Processed` messages')

        messagedata = processed_message.encode()

        self.receivedhashes_to_processedmessages[processed_message.echo] = (
            receiver_address, messagedata)

        self._maybe_send_processed(
            *self.receivedhashes_to_processedmessages[processed_message.echo])

    def _maybe_send_processed(self, receiver_address, messagedata):
        """ `Processed` messages must not go into the queue, otherwise nodes will deadlock
        waiting for the confirmation.
        """
        host_port = self.get_host_port(receiver_address)

        # `Processed` messages are sent at the end of the receive method, after the message is
        # successfully processed. It may be the case that the server is stopped
        # after the message is received but before the processed message is sent, under that
        # circumstance the udp socket would be unavailable and then an exception
        # is raised.
        #
        # This check verifies the udp socket is still available before trying
        # to send the `Processed` message. There must be *no context-switches after this test*.
        if self.transport.server.started:
            self.transport.send(
                self.raiden,
                host_port,
                messagedata,
            )

    def get_ping(self, nonce):
        """ Returns a signed Ping message.

        Note: Ping messages don't have an enforced ordering, so a Ping message
        with a higher nonce may be acknowledged first.
        """
        message = Ping(nonce)
        self.raiden.sign(message)
        message_data = message.encode()

        return message_data

    def send_raw_with_result(self, data, receiver_address):
        """ Sends data to receiver_address and returns an AsyncResult that will
        be set once the message is acknowledged.

        Always returns same AsyncResult instance for equal input.
        """
        host_port = self.get_host_port(receiver_address)
        echohash = sha3(data + receiver_address)

        if echohash not in self.senthashes_to_states:
            async_result = AsyncResult()
            self.senthashes_to_states[echohash] = SentMessageState(
                async_result,
                receiver_address,
            )
        else:
            async_result = self.senthashes_to_states[echohash].async_result

        if not async_result.ready():
            self.transport.send(
                self.raiden,
                host_port,
                data,
            )

        return async_result

    def set_node_network_state(self, node_address, node_state):
        """ Forward a network-state change for `node_address` to the node. """
        state_change = ActionChangeNodeNetworkState(node_address, node_state)
        self.raiden.handle_state_change(state_change)

    def receive(self, data):
        """ Handle one raw incoming packet.

        Oversized packets are dropped; packets already processed get their
        stored `Processed` reply resent; everything else is decoded and
        dispatched by message type.
        """
        if len(data) > UDP_MAX_MESSAGE_SIZE:
            log.error('receive packet larger than maximum size',
                      length=len(data))
            return

        # Repeat the 'PROCESSED' message if the message has been handled before
        echohash = sha3(data + self.raiden.address)
        if echohash in self.receivedhashes_to_processedmessages:
            self._maybe_send_processed(
                *self.receivedhashes_to_processedmessages[echohash])
            return

        message = decode(data)

        if isinstance(message, Processed):
            self.receive_processed(message)

        elif isinstance(message, Ping):
            self.receive_ping(message, echohash)

        elif isinstance(message, SignedMessage):
            self.receive_message(message, echohash)

        elif log.isEnabledFor(logging.ERROR):
            log.error(
                'Invalid message',
                message=hexlify(data),
            )

    def receive_processed(self, processed):
        """ Resolve the pending AsyncResult for the acknowledged echo hash
        with True; unknown echo hashes are only logged.
        """
        waitprocessed = self.senthashes_to_states.get(processed.echo)

        if waitprocessed is None:
            if log.isEnabledFor(logging.DEBUG):
                log.debug(
                    '`Processed` MESSAGE UNKNOWN ECHO',
                    node=pex(self.raiden.address),
                    echohash=pex(processed.echo),
                )

        else:
            if log.isEnabledFor(logging.DEBUG):
                log.debug(
                    '`Processed` MESSAGE RECEIVED',
                    node=pex(self.raiden.address),
                    receiver=pex(waitprocessed.receiver_address),
                    echohash=pex(processed.echo),
                )

            waitprocessed.async_result.set(True)

    def receive_ping(self, ping, echohash):
        """ Answer a Ping with a `Processed` message for its echo hash. """
        if ping_log.isEnabledFor(logging.DEBUG):
            ping_log.debug(
                'PING RECEIVED',
                node=pex(self.raiden.address),
                echohash=pex(echohash),
                message=ping,
                sender=pex(ping.sender),
            )

        processed_message = Processed(
            self.raiden.address,
            echohash,
        )

        try:
            self.maybe_send_processed(
                ping.sender,
                processed_message,
            )
        except (InvalidAddress, UnknownAddress) as e:
            log.debug("Couldn't send the `Processed` message", e=e)

    def receive_message(self, message, echohash):
        """ Hand a signed message to on_udp_message and, only if that
        succeeds, reply with a `Processed` message.

        InvalidAddress, UnknownAddress and UnknownTokenAddress are swallowed
        (and logged only when debug logging is enabled).
        """
        is_debug_log_enabled = log.isEnabledFor(logging.DEBUG)

        if is_debug_log_enabled:
            log.info('MESSAGE RECEIVED',
                     node=pex(self.raiden.address),
                     echohash=pex(echohash),
                     message=message,
                     message_sender=pex(message.sender))

        try:
            on_udp_message(self.raiden, message)

            # only send the Processed message if the message was handled without exceptions
            processed_message = Processed(
                self.raiden.address,
                echohash,
            )

            self.maybe_send_processed(
                message.sender,
                processed_message,
            )
        except (InvalidAddress, UnknownAddress, UnknownTokenAddress) as e:
            if is_debug_log_enabled:
                log.warning(str(e))
        else:
            if is_debug_log_enabled:
                log.debug(
                    'PROCESSED',
                    node=pex(self.raiden.address),
                    to=pex(message.sender),
                    echohash=pex(echohash),
                )
예제 #42
0
class BaseServer(object):
    """
    An abstract base class that implements some common functionality for the servers in gevent.

    :param listener: Either an address that the server should bind
        on or a :class:`gevent.socket.socket` instance that is already
        bound (and put into listening mode in case of TCP socket).

    :keyword handle: If given, the request handler. The request
        handler can be defined in a few ways. Most commonly,
        subclasses will implement a ``handle`` method as an
        instance method. Alternatively, a function can be passed
        as the ``handle`` argument to the constructor. In either
        case, the handler can later be changed by calling
        :meth:`set_handle`.

        When the request handler returns, the socket used for the
        request will be closed. Therefore, the handler must not return if
        the socket is still in use (for example, by manually spawned greenlets).

    :keyword spawn: If provided, is called to create a new
        greenlet to run the handler. By default,
        :func:`gevent.spawn` is used (meaning there is no
        artificial limit on the number of concurrent requests). Possible values for *spawn*:

        - a :class:`gevent.pool.Pool` instance -- ``handle`` will be executed
          using :meth:`gevent.pool.Pool.spawn` only if the pool is not full.
          While it is full, no new connections are accepted;
        - :func:`gevent.spawn_raw` -- ``handle`` will be executed in a raw
          greenlet which has a little less overhead than :class:`gevent.Greenlet` instances spawned by default;
        - ``None`` -- ``handle`` will be executed right away, in the :class:`Hub` greenlet.
          ``handle`` cannot use any blocking functions as it would mean switching to the :class:`Hub`.
        - an integer -- a shortcut for ``gevent.pool.Pool(integer)``

    .. versionchanged:: 1.1a1
       When the *handle* function returns from processing a connection,
       the client socket will be closed. This resolves the non-deterministic
       closing of the socket, fixing ResourceWarnings under Python 3 and PyPy.

    """
    # pylint: disable=too-many-instance-attributes,bare-except,broad-except

    #: the number of seconds to sleep in case there was an error in accept() call
    #: for consecutive errors the delay will double until it reaches max_delay
    #: when accept() finally succeeds the delay will be reset to min_delay again
    min_delay = 0.01
    max_delay = 1

    #: Sets the maximum number of consecutive accepts that a process may perform on
    #: a single wake up. High values give higher priority to high connection rates,
    #: while lower values give higher priority to already established connections.
    #: Default is 100.
    #:
    #: Note that, in case of multiple working processes on the same
    #: listening socket, it should be set to a lower value. (pywsgi.WSGIServer sets it
    #: to 1 when ``environ["wsgi.multiprocess"]`` is true)
    #:
    #: This is equivalent to libuv's `uv_tcp_simultaneous_accepts
    #: <http://docs.libuv.org/en/v1.x/tcp.html#c.uv_tcp_simultaneous_accepts>`_
    #: value. Setting the environment variable UV_TCP_SINGLE_ACCEPT to a true value
    #: (usually 1) changes the default to 1.
    max_accept = 100

    _spawn = Greenlet.spawn

    #: the default timeout that we wait for the client connections to close in stop()
    stop_timeout = 1

    fatal_errors = (errno.EBADF, errno.EINVAL, errno.ENOTSOCK)

    def __init__(self, listener, handle=None, spawn='default'):
        # The stop event starts out set: the server is "not started" until
        # start() clears it (see the ``started`` property).
        self._stop_event = Event()
        self._stop_event.set()
        self._watcher = None
        self._timer = None
        self._handle = None
        # XXX: FIXME: Subclasses rely on the presence or absence of the
        # `socket` attribute to determine whether we are open/should be opened.
        # Instead, have it be None.
        self.pool = None
        try:
            self.set_listener(listener)
            self.set_spawn(spawn)
            self.set_handle(handle)
            self.delay = self.min_delay
            self.loop = get_hub().loop
            if self.max_accept < 1:
                raise ValueError('max_accept must be positive int: %r' %
                                 (self.max_accept, ))
        except:
            self.close()
            raise

    def set_listener(self, listener):
        """Record the listening socket, or parse the address to bind later."""
        if hasattr(listener, 'accept'):
            if hasattr(listener, 'do_handshake'):
                raise TypeError(
                    'Expected a regular socket, not SSLSocket: %r' %
                    (listener, ))
            self.family = listener.family
            self.address = listener.getsockname()
            self.socket = listener
        else:
            self.family, self.address = parse_address(listener)

    def set_spawn(self, spawn):
        """Configure how handlers are spawned (see the class docstring)."""
        if spawn == 'default':
            self.pool = None
            # Copies the class-level default onto the instance; close()
            # pops it again from the instance __dict__.
            self._spawn = self._spawn
        elif hasattr(spawn, 'spawn'):
            self.pool = spawn
            self._spawn = spawn.spawn
        elif isinstance(spawn, integer_types):
            from gevent.pool import Pool
            self.pool = Pool(spawn)
            self._spawn = self.pool.spawn
        else:
            self.pool = None
            self._spawn = spawn
        if hasattr(self.pool, 'full'):
            self.full = self.pool.full
        if self.pool is not None:
            # Resume accepting whenever the pool frees a slot.
            self.pool._semaphore.rawlink(self._start_accepting_if_started)

    def set_handle(self, handle):
        """Install *handle* (or the subclass's ``handle`` method).

        :raises TypeError: if no handler is available at all.
        """
        if handle is not None:
            self.handle = handle
        if hasattr(self, 'handle'):
            self._handle = self.handle
        else:
            raise TypeError("'handle' must be provided")

    def _start_accepting_if_started(self, _event=None):
        if self.started:
            self.start_accepting()

    def start_accepting(self):
        """Start a read watcher on the listening socket, if none is active."""
        if self._watcher is None:
            # just stop watcher without creating a new one?
            self._watcher = self.loop.io(self.socket.fileno(), 1)
            self._watcher.start(self._do_read)

    def stop_accepting(self):
        """Stop and discard the read watcher and any pending retry timer."""
        if self._watcher is not None:
            self._watcher.stop()
            self._watcher.close()
            self._watcher = None
        if self._timer is not None:
            self._timer.stop()
            self._timer.close()
            self._timer = None

    def do_handle(self, *args):
        """Run the handler for one accepted connection.

        With ``spawn`` set to None the handler runs inline; otherwise it is
        passed to the spawn callable. If scheduling fails the connection is
        closed and the exception re-raised.
        """
        spawn = self._spawn
        handle = self._handle
        close = self.do_close

        try:
            if spawn is None:
                _handle_and_close_when_done(handle, close, args)
            else:
                spawn(_handle_and_close_when_done, handle, close, args)
        except:
            close(*args)
            raise

    def do_close(self, *args):
        pass

    def do_read(self):
        raise NotImplementedError()

    def _do_read(self):
        """Accept up to ``max_accept`` connections for one watcher wake-up."""
        for _ in xrange(self.max_accept):
            if self.full():
                self.stop_accepting()
                return
            try:
                args = self.do_read()
                self.delay = self.min_delay
                if not args:
                    return
            except:
                self.loop.handle_error(self, *sys.exc_info())
                ex = sys.exc_info()[1]
                if self.is_fatal_error(ex):
                    self.close()
                    sys.stderr.write('ERROR: %s failed with %s\n' %
                                     (self, str(ex) or repr(ex)))
                    return
                self._schedule_accept_retry()
                break
            else:
                try:
                    self.do_handle(*args)
                except:
                    self.loop.handle_error((args[1:], self), *sys.exc_info())
                    self._schedule_accept_retry()
                    break

    def _schedule_accept_retry(self):
        """Pause accepting and re-enable it after ``self.delay`` seconds.

        Each call doubles ``self.delay``, capped at ``max_delay``; a
        successful ``do_read`` resets it to ``min_delay``. When the delay is
        negative nothing is stopped or scheduled.
        """
        if self.delay >= 0:
            self.stop_accepting()
            self._timer = self.loop.timer(self.delay)
            self._timer.start(self._start_accepting_if_started)
            self.delay = min(self.max_delay, self.delay * 2)

    def full(self):
        # copied from self.pool
        # pylint: disable=method-hidden
        return False

    def __repr__(self):
        return '<%s at %s %s>' % (type(self).__name__, hex(
            id(self)), self._formatinfo())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._formatinfo())

    def _formatinfo(self):
        if hasattr(self, 'socket'):
            try:
                fileno = self.socket.fileno()
            except Exception as ex:
                fileno = str(ex)
            result = 'fileno=%s ' % fileno
        else:
            result = ''
        try:
            if isinstance(self.address, tuple) and len(self.address) == 2:
                result += 'address=%s:%s' % self.address
            else:
                result += 'address=%s' % (self.address, )
        except Exception as ex:
            result += str(ex) or '<error>'

        handle = self.__dict__.get('handle')
        if handle is not None:
            fself = getattr(handle, '__self__', None)
            try:
                if fself is self:
                    # Checks the __self__ of the handle in case it is a bound
                    # method of self to prevent recursively defined reprs.
                    handle_repr = '<bound method %s.%s of self>' % (
                        self.__class__.__name__,
                        handle.__name__,
                    )
                else:
                    handle_repr = repr(handle)

                result += ' handle=' + handle_repr
            except Exception as ex:
                result += str(ex) or '<error>'

        return result

    @property
    def server_host(self):
        """IP address that the server is bound to (string)."""
        if isinstance(self.address, tuple):
            return self.address[0]

    @property
    def server_port(self):
        """Port that the server is bound to (an integer)."""
        if isinstance(self.address, tuple):
            return self.address[1]

    def init_socket(self):
        """
        If the user initialized the server with an address rather than
        socket, then this function must create a socket, bind it, and
        put it into listening mode.

        It is not supposed to be called by the user, it is called by :meth:`start` before starting
        the accept loop.
        """

    @property
    def started(self):
        return not self._stop_event.is_set()

    def start(self):
        """Start accepting the connections.

        If an address was provided in the constructor, then also create a socket,
        bind it and put it into the listening mode.
        """
        self.init_socket()
        self._stop_event.clear()
        try:
            self.start_accepting()
        except:
            self.close()
            raise

    def close(self):
        """Close the listener socket and stop accepting."""
        self._stop_event.set()
        try:
            self.stop_accepting()
        finally:
            try:
                self.socket.close()
            except Exception:
                pass
            finally:
                self.__dict__.pop('socket', None)
                self.__dict__.pop('handle', None)
                self.__dict__.pop('_handle', None)
                self.__dict__.pop('_spawn', None)
                self.__dict__.pop('full', None)
                if self.pool is not None:
                    self.pool._semaphore.unlink(
                        self._start_accepting_if_started)
                    # If the pool's semaphore had a notifier already started,
                    # there's a reference cycle we're a part of
                    # (self->pool->semaphore-hub callback->semaphore)
                    # But we can't destroy self.pool, because self.stop()
                    # calls this method, and then wants to join self.pool()

    @property
    def closed(self):
        return not hasattr(self, 'socket')

    def stop(self, timeout=None):
        """
        Stop accepting the connections and close the listening socket.

        If the server uses a pool to spawn the requests, then
        :meth:`stop` also waits for all the handlers to exit. If there
        are still handlers executing after *timeout* has expired
        (default 1 second, :attr:`stop_timeout`), then the currently
        running handlers in the pool are killed.

        If the server does not use a pool, then this merely stops accepting connections;
        any spawned greenlets that are handling requests continue running until
        they naturally complete.
        """
        self.close()
        if timeout is None:
            timeout = self.stop_timeout
        if self.pool:
            self.pool.join(timeout=timeout)
            self.pool.kill(block=True, timeout=1)

    def serve_forever(self, stop_timeout=None):
        """Start the server if it hasn't been already started and wait until it's stopped."""
        # add test that serve_forever exists on stop()
        if not self.started:
            self.start()
        try:
            self._stop_event.wait()
        finally:
            Greenlet.spawn(self.stop, timeout=stop_timeout).join()

    def is_fatal_error(self, ex):
        """Return True if *ex* is a socket error whose first argument is in
        :attr:`fatal_errors`.

        An error constructed without arguments is treated as non-fatal;
        indexing its empty ``args`` would otherwise raise IndexError from
        inside the error-recovery path of :meth:`_do_read`.
        """
        return (isinstance(ex, _socket.error)
                and bool(ex.args)
                and ex.args[0] in self.fatal_errors)
예제 #43
0
class RaidenService:
    """ A Raiden node.

    Owns the node's key material, the write-ahead log (``self.wal``) holding
    the node state machine, the blockchain event listeners and the alarm task
    that polls them, and the transport (``self.protocol``). The constructor
    validates its inputs and then calls :meth:`start`.
    """
    def __init__(
        self,
        chain,
        default_registry,
        private_key_bin,
        transport,
        config,
        discovery=None,
    ):
        """ Validate configuration, set up node resources and start the node.

        Args:
            chain: Blockchain client wrapper; ``chain.client`` receives the
                stop event and ``chain.block_number()`` is queried on first run.
            default_registry: Registry proxy whose ``address`` is registered
                as the initial payment network in :meth:`start`.
            private_key_bin: The node's raw private key; must be 32 bytes.
            transport: Transport/protocol object started in :meth:`start`.
            config: Mapping with at least ``settle_timeout``,
                ``transport_type``, ``shutdown_timeout`` and
                ``database_path`` (plus ``external_ip``/``external_port``
                when the transport type is ``'udp'``).
            discovery: Endpoint discovery service; required when
                ``config['transport_type'] == 'udp'``.

        Raises:
            ValueError: If the private key is not 32 bytes, or the settle
                timeout is outside
                [NETTINGCHANNEL_SETTLE_TIMEOUT_MIN, NETTINGCHANNEL_SETTLE_TIMEOUT_MAX].
        """
        if not isinstance(private_key_bin,
                          bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        invalid_timeout = (
            config['settle_timeout'] < NETTINGCHANNEL_SETTLE_TIMEOUT_MIN
            or config['settle_timeout'] > NETTINGCHANNEL_SETTLE_TIMEOUT_MAX)
        if invalid_timeout:
            raise ValueError('settle_timeout must be in range [{}, {}]'.format(
                NETTINGCHANNEL_SETTLE_TIMEOUT_MIN,
                NETTINGCHANNEL_SETTLE_TIMEOUT_MAX))

        self.tokens_to_connectionmanagers = dict()
        # Payment identifier -> AsyncResults of transfers started by this node.
        self.identifier_to_results = defaultdict(list)

        self.chain = chain
        self.default_registry = default_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)

        if config['transport_type'] == 'udp':
            # Registration runs in the background while the rest of the node
            # is set up; it is joined further down before the node starts.
            endpoint_registration_event = gevent.spawn(
                discovery.register,
                self.address,
                config['external_ip'],
                config['external_port'],
            )
            endpoint_registration_event.link_exception(
                endpoint_registry_exception_handler)

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.protocol = transport

        self.blockchain_events = BlockchainEvents()
        self.alarm = AlarmTask(chain)
        self.shutdown_timeout = config['shutdown_timeout']
        self.stop_event = Event()
        self.start_event = Event()
        self.chain.client.inject_stop_event(self.stop_event)

        # Created by start() from the latest snapshot in storage.
        self.wal = None

        self.database_path = config['database_path']
        if self.database_path != ':memory:':
            database_dir = os.path.dirname(config['database_path'])
            os.makedirs(database_dir, exist_ok=True)

            self.database_dir = database_dir
            # Prevent concurrent acces to the same db
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_path = ':memory:'
            self.database_dir = None
            self.lock_file = None
            self.serialization_file = None
            self.db_lock = None

        if config['transport_type'] == 'udp':
            # If the endpoint registration fails the node will quit, this must
            # finish before starting the protocol
            endpoint_registration_event.join()

        # Lock used to serialize calls to `poll_blockchain_events`, this is
        # important to give a consistent view of the node state.
        self.event_poll_lock = gevent.lock.Semaphore()

        self.start()

    def start(self):
        """ Start the node.

        Restores the state machine from the latest snapshot (initialising it
        on first run), starts the alarm task and its callbacks, registers the
        default payment network, starts the transport and neighbour health
        checks, replays events that were not yet applied, and finally sets
        ``start_event``.
        """
        if self.stop_event and self.stop_event.is_set():
            self.stop_event.clear()

        if self.database_dir is not None:
            # timeout=0: do not wait if another holder has the lock.
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked

        # The database may be :memory:
        storage = sqlite.SQLiteStorage(self.database_path,
                                       serialize.PickleSerializer())
        self.wal, unapplied_events = wal.restore_from_latest_snapshot(
            node.state_transition,
            storage,
        )

        last_log_block_number = None
        # First run, initialize the basic state
        if self.wal.state_manager.current_state is None:
            block_number = self.chain.block_number()

            state_change = ActionInitNode(
                random.Random(),
                block_number,
            )
            self.wal.log_and_dispatch(state_change, block_number)
        else:
            # Get the last known block number after reapplying all the state changes from the log
            last_log_block_number = views.block_number(
                self.wal.state_manager.current_state)

        # The alarm task must be started after the snapshot is loaded or the
        # state is primed, the callbacks assume the node is initialized.
        self.alarm.start()
        self.alarm.register_callback(self.poll_blockchain_events)
        self.alarm.register_callback(self.set_block_number)

        # Registry registration must start *after* the alarm task. This
        # avoids corner cases where the registry is queried in block A, a new
        # block B is mined, and the alarm starts polling at block C.

        # If last_log_block_number is None, the wal.state_manager.current_state was
        # None in the log, meaning we don't have any events we care about, so just
        # read the latest state from the network
        self.register_payment_network(self.default_registry.address,
                                      last_log_block_number)

        # Start the protocol after the registry is queried to avoid warning
        # about unknown channels.
        queueids_to_queues = views.get_all_messagequeues(
            views.state_from_raiden(self))

        # TODO: remove the cyclic dependency between the protocol and this instance
        self.protocol.start(self, queueids_to_queues)

        # Health check needs the protocol layer
        self.start_neighbours_healthcheck()

        for event in unapplied_events:
            on_raiden_event(self, event)

        self.start_event.set()

    def start_neighbours_healthcheck(self):
        """ Start a transport health check for every known neighbour node,
        except the connection manager's bootstrap address. """
        for neighbour in views.all_neighbour_nodes(
                self.wal.state_manager.current_state):
            if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                self.start_health_check_for(neighbour)

    def stop(self):
        """ Stop the node.

        Signals ``stop_event``, stops the transport and the alarm task, waits
        (bounded by ``shutdown_timeout``) for their greenlets, uninstalls the
        blockchain event filters and releases the database lock if one is
        held.
        """
        # Needs to come before any greenlets joining
        self.stop_event.set()
        self.protocol.stop_and_wait()
        self.alarm.stop_async()

        wait_for = [self.alarm]
        wait_for.extend(getattr(self.protocol, 'greenlets', []))
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        gevent.wait(wait_for, timeout=self.shutdown_timeout)

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        try:
            with gevent.Timeout(self.shutdown_timeout):
                self.blockchain_events.uninstall_all_event_listeners()
        except (gevent.timeout.Timeout, RaidenShuttingDown):
            pass

        if self.db_lock is not None:
            self.db_lock.release()

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, pex(self.address))

    def set_block_number(self, block_number):
        """ Alarm callback: feed a ``Block`` state change for *block_number*. """
        state_change = Block(block_number)
        self.handle_state_change(state_change, block_number)

    def get_block_number(self):
        """ Return the block number recorded in the current node state. """
        return views.block_number(self.wal.state_manager.current_state)

    def handle_state_change(self, state_change, block_number=None):
        """ Log and dispatch *state_change* through the WAL, then handle every
        resulting event with ``on_raiden_event``.

        Args:
            state_change: The state change to apply.
            block_number: Block to record with the state change; defaults to
                the block number of the current node state.

        Returns:
            The list of events produced by the state transition.
        """
        log.debug('STATE CHANGE',
                  node=pex(self.address),
                  state_change=state_change)

        if block_number is None:
            block_number = self.get_block_number()

        event_list = self.wal.log_and_dispatch(state_change, block_number)

        for event in event_list:
            log.debug('EVENT', node=pex(self.address), chain_event=event)

            on_raiden_event(self, event)

        return event_list

    def set_node_network_state(self, node_address, network_state):
        """ Record a change in the reachability state of *node_address*.

        Goes through :meth:`handle_state_change` (like every other state
        change in this class) so any events the transition yields are
        handled rather than discarded.
        """
        state_change = ActionChangeNodeNetworkState(node_address,
                                                    network_state)
        self.handle_state_change(state_change, self.get_block_number())

    def start_health_check_for(self, node_address):
        """ Ask the transport to start health-checking *node_address*. """
        self.protocol.start_health_check(node_address)

    def poll_blockchain_events(self, block_number=None):  # pylint: disable=unused-argument
        """ Alarm callback: poll the installed filters and handle each event.

        *block_number* is accepted to match the alarm callback signature but
        is not used.
        """
        with self.event_poll_lock:
            for event in self.blockchain_events.poll_blockchain_events():
                on_blockchain_event(self, event)

    def sign(self, message):
        """ Sign message inplace.

        Raises:
            ValueError: If *message* is not a ``SignedMessage``.
        """
        if not isinstance(message, SignedMessage):
            raise ValueError('{} is not signable.'.format(repr(message)))

        message.sign(self.private_key, self.address)

    def register_payment_network(self, registry_address, from_block=None):
        """ Install listeners for *registry_address* and feed its current
        on-chain state to the state machine as a new payment network.

        Args:
            registry_address: Address of the registry contract.
            from_block: Block from which the event listeners start; passed
                through to ``add_proxies_listeners``.
        """
        proxies = get_relevant_proxies(
            self.chain,
            self.address,
            registry_address,
        )

        # Install the filters first to avoid missing changes, as a consequence
        # some events might be applied twice.
        self.blockchain_events.add_proxies_listeners(proxies, from_block)

        token_network_list = list()
        for manager in proxies.channel_managers:
            manager_address = manager.address
            netting_channel_proxies = proxies.channelmanager_nettingchannels[
                manager_address]
            network = get_token_network_state_from_proxies(
                self, manager, netting_channel_proxies)
            token_network_list.append(network)

        payment_network = PaymentNetworkState(
            registry_address,
            token_network_list,
        )

        state_change = ContractReceiveNewPaymentNetwork(payment_network)
        self.handle_state_change(state_change)

    def connection_manager_for_token(self, registry_address, token_address):
        """ Return the ConnectionManager for *token_address*, creating and
        caching it on first use.

        Raises:
            InvalidAddress: If *token_address* is not a binary address or is
                not a token network known under *registry_address*.
        """
        if not is_binary_address(token_address):
            raise InvalidAddress('token address is not valid.')

        known_token_networks = views.get_token_network_addresses_for(
            self.wal.state_manager.current_state,
            registry_address,
        )

        if token_address not in known_token_networks:
            raise InvalidAddress('token is not registered.')

        manager = self.tokens_to_connectionmanagers.get(token_address)

        if manager is None:
            manager = ConnectionManager(self, registry_address, token_address)
            self.tokens_to_connectionmanagers[token_address] = manager

        return manager

    def leave_all_token_networks(self):
        """ Dispatch ``ActionLeaveAllNetworks`` and handle its events.

        The state change goes through :meth:`handle_state_change` so the
        events it produces are passed to ``on_raiden_event``; calling
        ``self.wal.log_and_dispatch`` directly would record the state change
        but silently drop those events.
        """
        state_change = ActionLeaveAllNetworks()
        self.handle_state_change(state_change, self.get_block_number())

    def close_and_settle(self):
        """ Leave all token networks and, if any connection manager exists,
        block until all channels are settled. """
        log.info('raiden will close and settle all channels now')

        self.leave_all_token_networks()

        if self.tokens_to_connectionmanagers:
            waiting.wait_for_settle_all_channels(
                self,
                self.alarm.wait_time,
            )

    def mediated_transfer_async(
        self,
        token_network_identifier,
        amount,
        target,
        identifier,
    ):
        """ Transfer `amount` between this node and `target`.

        This method will start an asynchronous transfer, the transfer might fail
        or succeed depending on a couple of factors:

            - Existence of a path that can be used, through the usage of direct
              or intermediary channels.
            - Network speed, making the transfer sufficiently fast so it doesn't
              expire.

        Returns:
            The AsyncResult created by :meth:`start_mediated_transfer`.
        """

        async_result = self.start_mediated_transfer(
            token_network_identifier,
            amount,
            target,
            identifier,
        )

        return async_result

    def direct_transfer_async(self, token_network_identifier, amount, target,
                              identifier):
        """ Do a direct transfer with target.

        Direct transfers are non cancellable and non expirable, since these
        transfers are a signed balance proof with the transferred amount
        incremented.

        Because the transfer is non cancellable, there is a level of trust with
        the target. After the message is sent the target is effectively paid
        and then it is not possible to revert.

        This transfer should be used as an optimization, since only two packets
        are required to complete the transfer (from the payers perspective),
        whereas the mediated transfer requires 6 messages.

        Note: despite the name, this method does not return an async result;
        it starts a health check for *target* and dispatches an
        ``ActionTransferDirect`` state change, then returns ``None``. A
        default identifier is generated when *identifier* is ``None``.
        """

        self.protocol.start_health_check(target)

        if identifier is None:
            identifier = create_default_identifier()

        direct_transfer = ActionTransferDirect(
            token_network_identifier,
            target,
            identifier,
            amount,
        )

        self.handle_state_change(direct_transfer)

    def start_mediated_transfer(
        self,
        token_network_identifier,
        amount,
        target,
        identifier,
    ):
        """ Start a mediated transfer as initiator and return its AsyncResult.

        A default identifier is generated when *identifier* is ``None``; the
        identifier must not already have results registered.
        """

        self.protocol.start_health_check(target)

        if identifier is None:
            identifier = create_default_identifier()

        assert identifier not in self.identifier_to_results

        async_result = AsyncResult()
        self.identifier_to_results[identifier].append(async_result)

        secret = random_secret()
        init_initiator_statechange = initiator_init(
            self,
            identifier,
            amount,
            secret,
            token_network_identifier,
            target,
        )

        # TODO: implement the network timeout raiden.config['msg_timeout'] and
        # cancel the current transfer if it happens (issue #374)
        #
        # Dispatch the state change even if there are no routes to create the
        # wal entry.
        self.handle_state_change(init_initiator_statechange)

        return async_result

    def mediate_mediated_transfer(self, transfer: LockedTransfer):
        """ Dispatch the mediator init state change for *transfer*. """
        init_mediator_statechange = mediator_init(self, transfer)
        self.handle_state_change(init_mediator_statechange)

    def target_mediated_transfer(self, transfer: LockedTransfer):
        """ Dispatch the target init state change for *transfer*. """
        init_target_statechange = target_init(transfer)
        self.handle_state_change(init_target_statechange)
예제 #44
0
class KombuAmqpClient(object):
    """AMQP client built on kombu whose consumers and publisher run in
    gevent greenlets.

    Exchanges and consumers are registered with :meth:`add_exchange` /
    :meth:`add_consumer`, messages are queued with :meth:`publish`, and
    :meth:`run` / :meth:`stop` start and stop the background greenlets.
    All log output goes through the ``logger`` callable given to the
    constructor.
    """
    _SSL_PROTOCOLS = {
        "tlsv1": ssl.PROTOCOL_TLSv1,
        "sslv23": ssl.PROTOCOL_SSLv23
    }

    def __init__(self, logger, config, heartbeat=0):
        """
        :param logger: callable ``logger(msg, level=...)`` used for logging.
        :param config: settings object; ``servers`` is a comma/whitespace
            separated server list, and ``ha_mode``/SSL attributes plus
            ``user``/``password``/``port``/``vhost`` items are read.
        :param heartbeat: AMQP heartbeat in seconds; 0 disables the heartbeat
            greenlet.
        """
        self._logger = logger
        servers = re.compile(r'[,\s]+').split(config.servers)
        urls = self._parse_servers(servers, config)
        ssl_params = self._fetch_ssl_params(config)
        self._queue_args = {"x-ha-policy": "all"} if config.ha_mode else None
        self._heartbeat = float(heartbeat)
        self._connection_lock = Semaphore()
        # Signalled when consumers are added/removed to wake the consumer loop.
        self._consumer_event = Event()
        # Set once the consumer loop has (re)created kombu consumers.
        self._consumers_created_event = Event()
        self._publisher_queue = Queue()
        self._connection = kombu.Connection(
            urls,
            ssl=ssl_params,
            heartbeat=heartbeat,
            transport_options={'confirm_publish': True})
        self._connected = False
        self._exchanges = {}
        self._consumers = {}
        self._removed_consumers = []
        self._running = False
        self._consumers_changed = True
        self._consumer_gl = None
        self._publisher_gl = None
        self._heartbeat_gl = None

    # end __init__

    def get_exchange(self, name):
        """Return the registered exchange called *name*, or ``None``."""
        return self._exchanges.get(name)

    # end get_exchange

    def add_exchange(self, name, type='direct', durable=False, **kwargs):
        """Create and register a kombu exchange.

        :raises ValueError: if an exchange with *name* already exists.
        """
        if name in self._exchanges:
            raise ValueError("Exchange with name '%s' already exists" % name)
        exchange = kombu.Exchange(name, type=type, durable=durable, **kwargs)
        self._exchanges[name] = exchange
        return exchange

    # end add_exchange

    def add_consumer(self,
                     name,
                     exchange,
                     routing_key='',
                     callback=None,
                     durable=False,
                     wait=False,
                     **kwargs):
        """Register a consumer on a queue named *name* bound to *exchange*.

        The consumer greenlet is signalled to rebuild its consumers. With
        ``wait=True`` this blocks until that rebuild has happened.

        :raises ValueError: if a consumer with *name* already exists.
        """
        if name in self._consumers:
            raise ValueError("Consumer with name '%s' already exists" % name)
        exchange_obj = self.get_exchange(exchange)
        queue = kombu.Queue(name,
                            exchange_obj,
                            routing_key=routing_key,
                            durable=durable,
                            **kwargs)
        consumer = AttrDict(queue=queue, callback=callback)
        self._consumers[name] = consumer
        self._consumers_created_event.clear()
        self._consumer_event.set()
        self._consumers_changed = True
        if wait:
            self._consumers_created_event.wait()
        msg = 'KombuAmqpClient: Added consumer: %s' % name
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)
        return consumer

    # end add_consumer

    def remove_consumer(self, name):
        """Unregister consumer *name*; its queue is deleted by the consumer
        greenlet (or by :meth:`stop`).

        :raises ValueError: if no consumer with *name* exists.
        """
        if name not in self._consumers:
            raise ValueError("Consumer with name '%s' does not exist" % name)
        consumer = self._consumers.pop(name)
        self._removed_consumers.append(consumer)
        self._consumer_event.set()
        self._consumers_changed = True
        msg = 'KombuAmqpClient: Removed consumer: %s' % name
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)

    # end remove_consumer

    def publish(self, message, exchange, routing_key=None, **kwargs):
        """Queue *message* for the publisher greenlet. An empty string message
        is sent as ``None``."""
        if message is not None and isinstance(message, basestring) and \
                len(message) == 0:
            message = None
        msg = 'KombuAmqpClient: Publishing message to exchange %s, routing_key %s' % (
            exchange, routing_key)
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)
        self._publisher_queue.put(
            AttrDict(message=message,
                     exchange=exchange,
                     routing_key=routing_key,
                     kwargs=kwargs))

    # end publish

    def run(self):
        """Spawn the consumer and publisher greenlets (and the heartbeat
        greenlet when a heartbeat is configured)."""
        self._running = True
        self._consumer_gl = gevent.spawn(self._start_consuming)
        self._publisher_gl = gevent.spawn(self._start_publishing)
        if self._heartbeat:
            self._heartbeat_gl = gevent.spawn(self._heartbeat_check)

    # end run

    def stop(self):
        """Kill the greenlets, delete all consumer queues and close the main
        connection."""
        self._running = False
        if self._heartbeat_gl is not None:
            self._heartbeat_gl.kill()
        if self._publisher_gl is not None:
            self._publisher_gl.kill()
        if self._consumer_gl is not None:
            self._consumer_gl.kill()
        for consumer in (self._removed_consumers +
                         list(self._consumers.values())):
            self._delete_consumer(consumer)
        self._connection.close()

    # end stop

    def _delete_consumer(self, consumer):
        """Delete the consumer's queue if unused; channel errors are ignored."""
        msg = 'KombuAmqpClient: Removing queue %s' % consumer.queue.name
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)
        consumer.queue.maybe_bind(self._connection)
        try:
            consumer.queue.delete(if_unused=True, nowait=False)
        except self._connection.channel_errors:
            pass

    # end _delete_consumer

    def _start_consuming(self):
        """Consumer greenlet body: (re)connect, delete removed queues and
        drain events until the consumer set changes or the client stops."""
        errors = (self._connection.connection_errors +
                  self._connection.channel_errors)
        # A removed consumer whose deletion failed is retried after reconnect.
        removed_consumer = None
        msg = 'KombuAmqpClient: Starting consumer greenlet'
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)
        while self._running:
            try:
                self._ensure_connection(self._connection, "Consumer")
                self._connected = True

                while len(self._removed_consumers) > 0 or removed_consumer:
                    if removed_consumer is None:
                        removed_consumer = self._removed_consumers.pop(0)
                    self._delete_consumer(removed_consumer)
                    removed_consumer = None

                if len(list(self._consumers.values())) == 0:
                    msg = 'KombuAmqpClient: Waiting for consumer'
                    self._logger(msg, level=SandeshLevel.SYS_DEBUG)
                    self._consumer_event.wait()
                    self._consumer_event.clear()
                    continue

                consumers = [
                    kombu.Consumer(
                        self._connection,
                        queues=c.queue,
                        callbacks=[c.callback] if c.callback else None)
                    for c in list(self._consumers.values())
                ]
                msg = 'KombuAmqpClient: Created consumers %s' % str(
                    list(self._consumers.keys()))
                self._logger(msg, level=SandeshLevel.SYS_DEBUG)
                self._consumers_changed = False
                with nested(*consumers):
                    self._consumers_created_event.set()
                    while self._running and not self._consumers_changed:
                        try:
                            self._connection.drain_events(timeout=1)
                        except socket.timeout:
                            pass
            except errors as e:
                msg = 'KombuAmqpClient: Connection error in Kombu amqp consumer greenlet: %s' % str(
                    e)
                self._logger(msg, level=SandeshLevel.SYS_WARN)
                self._connected = False
                gevent.sleep(0.1)
            except Exception as e:
                msg = 'KombuAmqpClient: Error in Kombu amqp consumer greenlet: %s' % str(
                    e)
                self._logger(msg, level=SandeshLevel.SYS_ERR)
                self._connected = False
                gevent.sleep(0.1)
        msg = 'KombuAmqpClient: Exited consumer greenlet'
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)

    # end _start_consuming

    def _start_publishing(self):
        """Publisher greenlet body: publish queued payloads until stopped.

        Uses a private clone of the main connection. A payload that fails to
        publish is kept and retried after reconnecting. The clone is closed
        when the greenlet exits, including when :meth:`stop` kills it.
        """
        errors = (self._connection.connection_errors +
                  self._connection.channel_errors)
        payload = None
        connection = self._connection.clone()
        msg = 'KombuAmqpClient: Starting publisher greenlet'
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)
        try:
            while self._running:
                try:
                    self._ensure_connection(connection, "Publisher")
                    producer = kombu.Producer(connection)
                    while self._running:
                        if payload is None:
                            payload = self._publisher_queue.get()

                        exchange = self.get_exchange(payload.exchange)
                        producer.publish(payload.message,
                                         exchange=exchange,
                                         routing_key=payload.routing_key,
                                         **payload.kwargs)
                        payload = None
                except errors as e:
                    msg = 'KombuAmqpClient: Connection error in Kombu amqp publisher greenlet: %s' % str(
                        e)
                    self._logger(msg, level=SandeshLevel.SYS_WARN)
                    # Back off like the consumer greenlet does, so a
                    # persistent failure does not reconnect in a tight loop.
                    gevent.sleep(0.1)
                except Exception as e:
                    msg = 'KombuAmqpClient: Error in Kombu amqp publisher greenlet: %s' % str(
                        e)
                    self._logger(msg, level=SandeshLevel.SYS_ERR)
                    gevent.sleep(0.1)
        finally:
            # The cloned connection belongs to this greenlet only; release it
            # however the loop ends (normal exit or GreenletExit from kill()).
            try:
                connection.close()
            except Exception as e:
                msg = 'KombuAmqpClient: Error closing publisher connection: %s' % str(
                    e)
                self._logger(msg, level=SandeshLevel.SYS_DEBUG)
        msg = 'KombuAmqpClient: Exiting publisher greenlet'
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)

    # end _start_publishing

    def _heartbeat_check(self):
        """Heartbeat greenlet body: every half heartbeat interval, run
        kombu's heartbeat check while connected with consumers present."""
        while self._running:
            try:
                if self._connected and len(list(self._consumers.values())) > 0:
                    self._connection.heartbeat_check()
            except Exception as e:
                msg = 'KombuAmqpClient: Error in Kombu amqp heartbeat greenlet: %s' % str(
                    e)
                self._logger(msg, level=SandeshLevel.SYS_DEBUG)
            finally:
                gevent.sleep(float(self._heartbeat) / 2)

    # end _heartbeat_check

    def _ensure_connection(self, connection, name):
        """(Re)establish *connection*: close it, let kombu establish it via
        ``ensure_connection()``, then ``connect()``. *name* labels log output."""
        msg = 'KombuAmqpClient: Ensuring %s connection' % name
        self._logger(msg, level=SandeshLevel.SYS_DEBUG)
        connection.close()
        connection.ensure_connection()
        connection.connect()
        # Log the connection that was actually established (the publisher
        # uses a clone, not self._connection).
        msg = 'KombuAmqpClient: %s connection established %s' %\
            (name, str(connection))
        self._logger(msg, level=SandeshLevel.SYS_INFO)

    # end _ensure_connection

    @staticmethod
    def _parse_servers(servers, config):
        """Build ``pyamqp://`` URLs from ``[user[:password]@]host[:port]``
        entries, filling missing parts (and vhost) from ``config[...]``."""
        required_keys = ['user', 'password', 'port', 'vhost']
        urls = []
        for server in servers:
            match = re.match(
                r"(?:(?P<user>.*?)(?::(?P<password>.*?))*@)*(?P<host>.*?)(?::(?P<port>\d+))*$",
                server)
            if match:
                host = match.groupdict().copy()
                for key in required_keys:
                    if key not in host or host[key] is None:
                        host[key] = config[key]
                url = "pyamqp://%(user)s:%(password)s@%(host)s:%(port)s/%(vhost)s" % host
                urls.append(url)
        return urls

    # end _parse_servers

    @classmethod
    def _fetch_ssl_params(cls, config):
        """Return ``False`` when SSL is disabled, otherwise the kombu SSL
        options dict (or ``True`` when no specific option is configured)."""
        if not config.use_ssl:
            return False
        ssl_params = dict()
        if config.ssl_version:
            ssl_params['ssl_version'] = cls._validate_ssl_version(
                config.ssl_version)
        if config.ssl_keyfile:
            ssl_params['keyfile'] = config.ssl_keyfile
        if config.ssl_certfile:
            ssl_params['certfile'] = config.ssl_certfile
        if config.ssl_ca_certs:
            ssl_params['ca_certs'] = config.ssl_ca_certs
            ssl_params['cert_reqs'] = ssl.CERT_REQUIRED
        return ssl_params or True

    # end _fetch_ssl_params

    @classmethod
    def _validate_ssl_version(cls, version):
        """Map a case-insensitive version name to an ``ssl.PROTOCOL_*`` value.

        :raises RuntimeError: for an unsupported version name.
        """
        version = version.lower()
        try:
            return cls._SSL_PROTOCOLS[version]
        except KeyError:
            raise RuntimeError('Invalid SSL version: {}'.format(version))
예제 #45
0
class Worker(object):
    """Per-gate worker process: receives requests from the master over
    *stream* and replies (or pushes messages upstream) over the same stream.
    """
    def __init__(self, stream, gate):
        """Register this worker globally, set up process title, logging and
        the HTTP middleware chain for *gate*."""
        aj.worker = self
        self.stream = stream
        self.gate = gate
        aj.master = False
        os.setpgrp()
        setproctitle.setproctitle('%s worker [%s]' %
                                  (sys.argv[0], self.gate.name))
        set_log_params(tag=self.gate.log_tag)
        init_log_forwarding(self.send_log_event)

        logging.info(
            'New worker "%s" PID %s, EUID %s, EGID %s',
            self.gate.name,
            os.getpid(),
            os.geteuid(),
            os.getegid(),
        )

        self.context = Context(parent=aj.context)
        self.context.session = self.gate.session
        self.context.worker = self
        self.handler = HttpMiddlewareAggregator([
            AuthenticationMiddleware.get(self.context),
            CentralDispatcher.get(self.context),
        ])

        # Set when the master pushes a 'config-data' message; awaited by
        # reload_master_config().
        self._master_config_reloaded = Event()

    def demote(self, uid):
        """Switch the process to *uid* (with its primary and supplementary
        groups).

        Does nothing if already running as *uid*; only logs a warning when
        not running as root, since setuid() would fail.
        """
        try:
            username = pwd.getpwuid(uid).pw_name
            gid = pwd.getpwuid(uid).pw_gid
        except KeyError:
            # Unknown uid: no passwd entry, so reuse the uid as the gid.
            username = None
            gid = uid

        if os.getuid() == uid:
            return
        else:
            if os.getuid() != 0:
                logging.warning(
                    'Running as a limited user, setuid() unavailable!')
                return

        logging.info('Worker %s is demoting to UID %s / GID %s...',
                     os.getpid(), uid, gid)

        groups = [
            g.gr_gid for g in grp.getgrall()
            if username in g.gr_mem or g.gr_gid == gid
        ]
        # Order matters: groups and gid must be changed while still root.
        os.setgroups(groups)
        os.setgid(gid)
        os.setuid(uid)
        logging.info('...done, new EUID %s EGID %s', os.geteuid(),
                     os.getegid())

    def run(self):
        """Main loop: drop privileges as configured, then dispatch incoming
        stream messages until an empty message is received."""
        if self.gate.restricted:
            restricted_user = aj.config.data['restricted_user']
            self.demote(pwd.getpwnam(restricted_user).pw_uid)
        else:
            if self.gate.initial_identity:
                AuthenticationService.get(self.context).login(
                    self.gate.initial_identity, demote=True)

        try:
            socket_namespaces = {}
            while True:
                rq = self.stream.recv()

                if not rq:
                    return

                if rq.object['type'] == 'http':
                    gevent.spawn(self.handle_http_request, rq)

                if rq.object['type'] == 'socket':
                    msg = rq.object['message']
                    nsid = rq.object['namespace']
                    event = rq.object['event']

                    if event == 'connect':
                        socket_namespaces[nsid] = WorkerSocketNamespace(
                            self.context, nsid)

                    namespace = socket_namespaces.get(nsid)
                    if namespace is None:
                        # An event for a namespace that never connected (or
                        # already disconnected) would otherwise raise KeyError
                        # and end the whole worker loop.
                        logging.warning(
                            'Ignoring socket event %r for unknown namespace %r',
                            event, nsid)
                        continue

                    namespace.process_event(event, msg)

                    if event == 'disconnect':
                        namespace.destroy()
                        # Forget the namespace so disconnected ones do not
                        # accumulate for the lifetime of the worker.
                        socket_namespaces.pop(nsid, None)
                        logging.debug(
                            'Socket disconnected, destroying endpoints')

                if rq.object['type'] == 'config-data':
                    logging.debug('Received a config update')
                    aj.config.data = rq.object['data']
                    self._master_config_reloaded.set()

                if rq.object['type'] == 'session-list':
                    logging.debug('Received a session list update')
                    aj.sessions = rq.object['data']

        # pylint: disable=W0703
        except Exception:
            logging.error('Worker crashed!')
            traceback.print_exc()

    def terminate(self):
        """Ask the master to terminate this worker."""
        self.send_to_upstream({
            'type': 'terminate',
        })

    def update_sessionlist(self):
        """Ask the master to send an updated session list."""
        self.send_to_upstream({
            'type': 'update-sessionlist',
        })

    def restart_master(self):
        """Ask the master process to restart."""
        self.send_to_upstream({
            'type': 'restart-master',
        })

    def reload_master_config(self):
        """Ask the master to reload its config and block until the master
        pushes the new config data back to this worker."""
        self.send_to_upstream({
            'type': 'reload-config',
        })
        self._master_config_reloaded.wait()
        self._master_config_reloaded.clear()

    def send_log_event(self, method, message, *args, **kwargs):
        """Forward a log record upstream.

        Like the ``logging`` module, *message* is %-formatted only when
        *args* are given, so an argument-less message containing a literal
        ``%`` is forwarded unchanged instead of raising.
        """
        if args:
            message = message % args
        self.send_to_upstream({
            'type': 'log',
            'method': method,
            'message': message,
            'kwargs': kwargs,
        })

    def handle_http_request(self, rq):
        """Run an HTTP request through the middleware stack and reply with
        the response (or with error details if handling raised)."""
        response_object = {
            'type': 'http',
        }

        try:
            http_context = HttpContext.deserialize(
                rq.object['context'].encode())
            logging.debug('                    ... %s %s', http_context.method,
                          http_context.path)

            # Generate response
            stack = HttpMiddleware.all(self.context)
            content = HttpMiddlewareAggregator(stack + [self.handler]).handle(
                http_context)
            # ---

            http_context.add_header('X-Worker-Name', str(self.gate.name))

            response_object['content'] = list(content)
            response_object['status'] = http_context.status
            response_object['headers'] = http_context.headers
            self.stream.reply(rq, response_object)
        # pylint: disable=W0703
        except Exception as e:
            logging.error(traceback.format_exc())
            response_object.update({
                'error': str(e),
                'exception': repr(e),
            })
            self.stream.reply(rq, response_object)

    def send_to_upstream(self, obj):
        """Send *obj* to the master as an unsolicited message."""
        self.stream.reply(None, obj)
예제 #46
0
class UDPTransport(Runnable):
    """Raiden transport that exchanges messages as raw UDP datagrams.

    Outgoing protocol messages are kept in one ``NotifyingQueue`` per
    ``QueueIdentifier``; each queue is drained by a ``single_queue_send``
    greenlet. Delivery is tracked with one ``AsyncResult`` per message
    identifier, resolved when a ``Delivered`` message arrives (or set to
    ``False`` by ``stop``). Ping/Pong messages feed one healthcheck greenlet
    per recipient. Incoming datagrams are dispatched by ``receive``, which is
    installed as the handler of the underlying ``DatagramServer``.
    """

    # Maximum size, in bytes, of an encoded message; enforced both when
    # queueing (send_async) and when receiving (receive).
    UDP_MAX_MESSAGE_SIZE = 1200
    # Class-level defaults: start() binds per-node instance loggers and stop()
    # deletes them again, falling back to these.
    log = log
    log_healthcheck = log_healthcheck

    def __init__(self, address, discovery, udpsocket, throttle_policy, config):
        super().__init__()
        # these values are initialized by the start method
        self.queueids_to_queues: Dict = dict()
        self.raiden: RaidenService
        self.message_handler: MessageHandler

        self.discovery = discovery
        self.config = config
        self.address = address

        self.retry_interval = config["retry_interval"]
        self.retries_before_backoff = config["retries_before_backoff"]
        self.nat_keepalive_retries = config["nat_keepalive_retries"]
        self.nat_keepalive_timeout = config["nat_keepalive_timeout"]
        self.nat_invitation_timeout = config["nat_invitation_timeout"]

        # Set while the transport is stopped: start() clears it, stop() sets it.
        self.event_stop = Event()
        self.event_stop.set()

        self.greenlets = list()
        self.addresses_events = dict()

        self.messageids_to_asyncresults = dict()

        # Maps the addresses to a dict with the latest nonce (using a dict
        # because python integers are immutable)
        self.nodeaddresses_to_nonces = dict()

        # Memoize discovery lookups (address -> host/port) for CACHE_TTL.
        cache = cachetools.TTLCache(maxsize=50, ttl=CACHE_TTL)
        cache_wrapper = cachetools.cached(cache=cache)
        self.get_host_port = cache_wrapper(discovery.get)

        self.throttle_policy = throttle_policy
        pool = gevent.pool.Pool()
        self.server = DatagramServer(udpsocket,
                                     handle=self.receive,
                                     spawn=pool)

    def start(  # type: ignore
            self,
            raiden_service: RaidenService,
            message_handler: MessageHandler,
            prev_auth_data: str,  # pylint: disable=unused-argument
    ):
        """Start the UDP server and the Runnable main loop.

        Raises:
            RuntimeError: if the transport is already running.
        """
        if not self.event_stop.ready():
            raise RuntimeError("UDPTransport started while running")

        self.event_stop.clear()
        self.raiden = raiden_service
        self.log = log.bind(node=pex(self.raiden.address))
        self.log_healthcheck = log_healthcheck.bind(
            node=pex(self.raiden.address))
        self.message_handler = message_handler

        # server.stop() clears the handle and the pool. Since this may be a
        # restart the handle must always be set
        self.server.set_handle(self.receive)
        pool = gevent.pool.Pool()
        self.server.set_spawn(pool)

        self.server.start()
        self.log.debug("UDP started")
        super().start()

    def _run(self):  # pylint: disable=method-hidden
        """ Runnable main method, perform wait on long-running subtasks """
        try:
            self.event_stop.wait()
        except gevent.GreenletExit:  # killed without exception
            self.event_stop.set()
            gevent.killall(self.greenlets)  # kill children
            raise  # re-raise to keep killed status
        except Exception:
            self.stop()  # ensure cleanup and wait on subtasks
            raise

    def stop(self):
        """Stop the transport; a no-op if it is already stopped.

        Stops accepting packets, waits for the outgoing greenlets, closes the
        socket and resolves every pending AsyncResult with ``False``.
        """
        if self.event_stop.ready():
            return  # double call, happens on normal stop, ignore

        self.event_stop.set()

        # Stop handling incoming packets, but don't close the socket. The
        # socket can only be safely closed after all outgoing tasks are stopped
        self.server.stop_accepting()

        # Stop processing the outgoing queues
        gevent.wait(self.greenlets)

        # All outgoing tasks are stopped. Now it's safe to close the socket. At
        # this point there might be some incoming message being processed,
        # keeping the socket open is not useful for these.
        self.server.stop()

        # Calling `.close()` on a gevent socket doesn't actually close the underlying os socket
        # so we do that ourselves here.
        # See: https://github.com/gevent/gevent/blob/master/src/gevent/_socket2.py#L208
        # and: https://groups.google.com/forum/#!msg/gevent/Ro8lRra3nH0/ZENgEXrr6M0J
        try:
            self.server._socket.close()  # pylint: disable=protected-access
        except socket.error:
            pass

        # Set all the pending results to False
        for async_result in self.messageids_to_asyncresults.values():
            async_result.set(False)

        self.log.debug("UDP stopped")
        # Drop the per-node loggers bound in start(); the class-level
        # loggers are used again afterwards.
        del self.log_healthcheck
        del self.log

    def get_health_events(self, recipient):
        """ Starts a healthcheck task for `recipient` and returns a
        HealthEvents with locks to react on its current state.
        """
        if recipient not in self.addresses_events:
            self.start_health_check(recipient)

        return self.addresses_events[recipient]

    def whitelist(self, address: Address):  # pylint: disable=no-self-use,unused-argument
        """Whitelist peer address to receive communications from

        This may be called before transport is started, to ensure events generated during
        start are handled properly.
        PS: udp currently doesn't do whitelisting, method defined for compatibility with matrix
        """
        return

    def start_health_check(self, recipient):
        """ Starts a task for healthchecking `recipient` if there is not
        one yet.

        It also whitelists the address
        """
        if recipient not in self.addresses_events:
            self.whitelist(recipient)  # noop for now, for compatibility
            ping_nonce = self.nodeaddresses_to_nonces.setdefault(
                recipient,
                {"nonce": 0}  # HACK: Allows the task to mutate the object
            )

            events = healthcheck.HealthEvents(event_healthy=Event(),
                                              event_unhealthy=Event())

            self.addresses_events[recipient] = events

            greenlet_healthcheck = gevent.spawn(
                healthcheck.healthcheck,
                self,
                recipient,
                self.event_stop,
                events.event_healthy,
                events.event_unhealthy,
                self.nat_keepalive_retries,
                self.nat_keepalive_timeout,
                self.nat_invitation_timeout,
                ping_nonce,
            )
            greenlet_healthcheck.name = f"Healthcheck for {pex(recipient)}"
            greenlet_healthcheck.link_exception(self.on_error)
            self.greenlets.append(greenlet_healthcheck)

    def init_queue_for(self, queue_identifier: QueueIdentifier,
                       items: List[QueueItem_T]) -> NotifyingQueue:
        """ Create the queue identified by the queue_identifier
        and initialize it with `items`.

        A ``single_queue_send`` greenlet is spawned to drain the new queue;
        its failures are reported through ``self.on_error``.
        """
        recipient = queue_identifier.recipient
        queue = self.queueids_to_queues.get(queue_identifier)
        assert queue is None

        queue = NotifyingQueue(items=items)
        self.queueids_to_queues[queue_identifier] = queue

        events = self.get_health_events(recipient)

        greenlet_queue = gevent.spawn(
            single_queue_send,
            self,
            recipient,
            queue,
            queue_identifier,
            self.event_stop,
            events.event_healthy,
            events.event_unhealthy,
            self.retries_before_backoff,
            self.retry_interval,
            self.retry_interval * 10,
        )

        if queue_identifier.channel_identifier == CHANNEL_IDENTIFIER_GLOBAL_QUEUE:
            greenlet_queue.name = f"Queue for {pex(recipient)} - global"
        else:
            greenlet_queue.name = (
                f"Queue for {pex(recipient)} - {queue_identifier.channel_identifier}"
            )

        greenlet_queue.link_exception(self.on_error)
        self.greenlets.append(greenlet_queue)

        self.log.debug("new queue created for",
                       queue_identifier=queue_identifier,
                       items_qty=len(items))

        return queue

    def get_queue_for(self,
                      queue_identifier: QueueIdentifier) -> NotifyingQueue:
        """ Return the queue identified by the given queue identifier.

        If the queue doesn't exist it will be instantiated.
        """
        queue = self.queueids_to_queues.get(queue_identifier)

        if queue is None:
            items: List[QueueItem_T] = list()
            queue = self.init_queue_for(queue_identifier, items)

        return queue

    def send_async(self, queue_identifier: QueueIdentifier,
                   message: SignedRetrieableMessage):
        """ Send a new ordered message to recipient.

        Messages that use the same `queue_identifier` are ordered.

        Raises:
            ValueError: invalid recipient address, a transport-level message
                type (Delivered/Ping/Pong), or an encoded message larger than
                ``UDP_MAX_MESSAGE_SIZE``.
        """
        recipient = queue_identifier.recipient
        if not is_binary_address(recipient):
            raise ValueError("Invalid address {}".format(pex(recipient)))

        # These are not protocol messages, but transport specific messages
        if isinstance(message, (Delivered, Ping, Pong)):
            raise ValueError("Do not use send for {} messages".format(
                message.__class__.__name__))

        messagedata = message.encode()
        if len(messagedata) > self.UDP_MAX_MESSAGE_SIZE:
            raise ValueError("message size exceeds the maximum {}".format(
                self.UDP_MAX_MESSAGE_SIZE))

        # message identifiers must be unique
        message_id = message.message_identifier

        # ignore duplicates
        if message_id not in self.messageids_to_asyncresults:
            self.messageids_to_asyncresults[message_id] = AsyncResult()

            queue = self.get_queue_for(queue_identifier)
            queue.put((messagedata, message_id))
            assert queue.is_set()

            self.log.debug(
                "Message queued",
                queue_identifier=queue_identifier,
                queue_size=len(queue),
                message=message,
            )

    def send_global(self, room: str, message: Message) -> None:  # pylint: disable=unused-argument
        """ This method exists only for interface compatibility with MatrixTransport """
        self.log.warning("UDP is unable to send global messages. Ignoring")

    def maybe_send(self, recipient: Address, message: Message):
        """ Send message to recipient if the transport is running.

        Raises:
            InvalidAddress: if ``recipient`` is not a binary address.
        """

        if not is_binary_address(recipient):
            raise InvalidAddress("Invalid address {}".format(pex(recipient)))

        messagedata = message.encode()
        host_port = self.get_host_port(recipient)

        self.maybe_sendraw(host_port, messagedata)

    def maybe_sendraw_with_result(self, recipient: Address, messagedata: bytes,
                                  message_id: UDPMessageID) -> AsyncResult:
        """ Send message to recipient if the transport is running.

        Returns:
            An AsyncResult that will be set once the message is delivered. As
            long as the message has not been acknowledged with a Delivered
            message the function will return the same AsyncResult.
        """
        async_result = self.messageids_to_asyncresults.get(message_id)
        if async_result is None:
            async_result = AsyncResult()
            self.messageids_to_asyncresults[message_id] = async_result

        host_port = self.get_host_port(recipient)
        self.maybe_sendraw(host_port, messagedata)

        return async_result

    def maybe_sendraw(self, host_port: Tuple[str, int], messagedata: bytes):
        """ Send message to recipient if the transport is running. """

        # Don't sleep if timeout is zero, otherwise a context-switch is done
        # and the message is delayed, increasing its latency
        sleep_timeout = self.throttle_policy.consume(1)
        if sleep_timeout:
            gevent.sleep(sleep_timeout)

        # Check the udp socket is still available before trying to send the
        # message. There must be *no context-switches after this test*.
        if hasattr(self.server, "socket"):
            self.server.sendto(messagedata, host_port)

    def receive(
            self,
            messagedata: bytes,
            host_port: Tuple[str, int]  # pylint: disable=unused-argument
    ) -> bool:
        """ Handle an UDP packet.

        Returns:
            True if the packet was decoded and dispatched, False if it was
            rejected (too large, undecodable, or unknown command id).
        """
        # pylint: disable=unidiomatic-typecheck

        if len(messagedata) > self.UDP_MAX_MESSAGE_SIZE:
            self.log.warning(
                "Invalid message: Packet larger than maximum size",
                message=encode_hex(messagedata),
                length=len(messagedata),
            )
            return False

        try:
            message = decode(messagedata)
        except InvalidProtocolMessage as e:
            self.log.warning("Invalid protocol message",
                             error=str(e),
                             message=encode_hex(messagedata))
            return False

        if type(message) == Pong:
            assert isinstance(message, Pong), MYPY_ANNOTATION
            self.receive_pong(message)
        elif type(message) == Ping:
            assert isinstance(message, Ping), MYPY_ANNOTATION
            self.receive_ping(message)
        elif type(message) == Delivered:
            assert isinstance(message, Delivered), MYPY_ANNOTATION
            self.receive_delivered(message)
        elif message is not None:
            assert isinstance(message, SignedRetrieableMessage)
            self.receive_message(message)
        else:
            self.log.warning("Invalid message: Unknown cmdid",
                             message=encode_hex(messagedata))
            return False

        return True

    def receive_message(self, message: SignedRetrieableMessage):
        """ Handle a Raiden protocol message.

        The protocol requires durability of the messages. The UDP transport
        relies on the node's WAL for durability. The message will be converted
        to a state change, saved to the WAL, and *processed* before the
        durability is confirmed, which is a stronger property than what is
        required of any transport.
        """
        self.raiden.on_message(message)

        # Sending Delivered after the message is decoded and *processed*
        # gives a stronger guarantee than what is required from a
        # transport.
        #
        # Alternatives are, from weakest to strongest options:
        # - Just save it on disk and asynchronously process the messages
        # - Decode it, save to the WAL, and asynchronously process the
        #   state change
        # - Decode it, save to the WAL, and process it (the current
        #   implementation)
        delivered_message = Delivered(
            delivered_message_identifier=message.message_identifier)
        self.raiden.sign(delivered_message)

        self.maybe_send(message.sender, delivered_message)

    def receive_delivered(self, delivered: Delivered):
        """ Handle a Delivered message.

        The Delivered message is how the UDP transport guarantees persistence
        by the partner node. The message itself is not part of the raiden
        protocol, but it's required by this transport to provide the required
        properties.
        """
        self.raiden.on_message(delivered)

        message_id = delivered.delivered_message_identifier
        # Look up and remove the entry in the same (this transport's)
        # dictionary; removing it avoids a memory leak.
        async_result = self.messageids_to_asyncresults.pop(message_id, None)

        if async_result is not None:
            async_result.set()
        else:
            self.log.warning("Unknown delivered message received",
                             message_id=message_id)

    # Pings and Pongs are used to check the health status of another node. They
    # are /not/ part of the raiden protocol, only part of the UDP transport,
    # therefore these messages are not forwarded to the message handler.
    def receive_ping(self, ping: Ping):
        """ Handle a Ping message by answering with a Pong. """

        self.log_healthcheck.debug("Ping received",
                                   message_id=ping.nonce,
                                   message=ping,
                                   sender=pex(ping.sender))

        pong = Pong(nonce=ping.nonce)
        self.raiden.sign(pong)

        try:
            self.maybe_send(ping.sender, pong)
        except (InvalidAddress, UnknownAddress) as e:
            self.log.debug("Couldn't send the `Pong` message", e=e)

    def receive_pong(self, pong: Pong):
        """ Handles a Pong message.

        Resolves the AsyncResult registered under ``("ping", nonce, sender)``
        with True, if any.
        """

        message_id = ("ping", pong.nonce, pong.sender)
        async_result = self.messageids_to_asyncresults.get(message_id)

        if async_result is not None:
            self.log_healthcheck.debug("Pong received",
                                       sender=pex(pong.sender),
                                       message_id=pong.nonce)

            async_result.set(True)

        else:
            self.log_healthcheck.warning("Unknown pong received",
                                         message_id=message_id)

    def get_ping(self, nonce: Nonce) -> bytes:
        """ Returns a signed Ping message.

        Note: Ping messages don't have an enforced ordering, so a Ping message
        with a higher nonce may be acknowledged first.
        """
        message = Ping(nonce=nonce,
                       current_protocol_version=constants.PROTOCOL_VERSION)
        self.raiden.sign(message)
        return message.encode()

    def set_node_network_state(self, node_address: Address, node_state):
        """Record a network-state change for ``node_address`` via the node's state machine."""
        state_change = ActionChangeNodeNetworkState(node_address, node_state)
        self.raiden.handle_and_track_state_change(state_change)
예제 #47
0
class uWSGIWebSocket(object):  # pragma: no cover
    """
    This wrapper class provides a uWSGI WebSocket interface that is
    compatible with eventlet's implementation.

    Two modes are supported:

    * uWSGI builds exposing ``uwsgi.request_context`` (>= 2.1.x): messages are
      sent and received directly using that request context.
    * older builds (compatibility mode): ``send`` only queues messages and
      sets an event; ``wait`` flushes the queue and polls for incoming data.
      A helper greenlet sets the same event whenever ``select`` reports the
      connection's file descriptor as readable.
    """
    def __init__(self, app):
        self.app = app
        self._sock = None
        # Normally established by __call__; pre-initialised so that close()
        # does not fail with AttributeError if it runs before __call__.
        self._req_ctx = None
        self._select_greenlet = None
        self._event = None

    def __call__(self, environ, start_response):
        """Perform the WebSocket handshake, pick the I/O mode, then run ``app``."""
        self._sock = uwsgi.connection_fd()
        self.environ = environ

        uwsgi.websocket_handshake()

        self._req_ctx = None
        if hasattr(uwsgi, 'request_context'):
            # uWSGI >= 2.1.x with support for api access across-greenlets
            self._req_ctx = uwsgi.request_context()
        else:
            # use event and queue for sending messages
            from gevent.event import Event
            from gevent.queue import Queue
            from gevent.select import select
            self._event = Event()
            self._send_queue = Queue()

            # spawn a select greenlet
            def select_greenlet_runner(fd, event):
                """Sets event when data becomes available to read on fd."""
                while True:
                    event.set()
                    try:
                        select([fd], [], [])[0]
                    except ValueError:
                        # fd is no longer selectable (connection gone)
                        break

            self._select_greenlet = gevent.spawn(select_greenlet_runner,
                                                 self._sock, self._event)

        self.app(self)

    def close(self):
        """Disconnects uWSGI from the client."""
        uwsgi.disconnect()
        if self._req_ctx is None and self._select_greenlet is not None:
            # better kill it here in case wait() is not called again
            self._select_greenlet.kill()
            self._event.set()

    def _send(self, msg):
        """Transmits message either in binary or UTF-8 text mode,
        depending on its type (``bytes`` -> binary, anything else -> text)."""
        if isinstance(msg, bytes):
            method = uwsgi.websocket_send_binary
        else:
            method = uwsgi.websocket_send
        if self._req_ctx is not None:
            method(msg, request_context=self._req_ctx)
        else:
            method(msg)

    def _decode_received(self, msg):
        """Returns either bytes or str, depending on message type.

        Non-bytes input is returned as-is. For bytes, a first byte >= 48
        (ASCII ``'0'``) marks a text payload, which is decoded as UTF-8;
        anything else (including an empty payload) is returned unchanged.
        """
        if not isinstance(msg, bytes):
            # already decoded - do nothing
            return msg
        if not msg:
            # empty payload: nothing to inspect, and indexing it would fail
            return msg
        # only decode from utf-8 if message is not binary data
        # (bytearray indexing yields an int on both Python 2 and 3)
        first_byte = bytearray(msg[:1])[0]
        if first_byte >= 48:  # no binary
            return msg.decode('utf-8')
        # binary message, don't try to decode
        return msg

    def send(self, msg):
        """Queues a message for sending. Real transmission is done in
        wait method.
        Sends directly if uWSGI version is new enough."""
        if self._req_ctx is not None:
            self._send(msg)
        else:
            self._send_queue.put(msg)
            self._event.set()

    def wait(self):
        """Waits and returns received messages.
        If running in compatibility mode for older uWSGI versions,
        it also sends messages that have been queued by send().
        A return value of None means that connection was closed.
        This must be called repeatedly. For uWSGI < 2.1.x it must
        be called from the main greenlet."""
        while True:
            if self._req_ctx is not None:
                try:
                    msg = uwsgi.websocket_recv(request_context=self._req_ctx)
                except IOError:  # connection closed
                    return None
                return self._decode_received(msg)

            # we wake up at least every 3 seconds to let uWSGI
            # do its ping/ponging
            event_set = self._event.wait(timeout=3)
            if event_set:
                self._event.clear()
                # maybe there is something to send: drain the queue first,
                # then transmit
                msgs = []
                while True:
                    try:
                        msgs.append(self._send_queue.get(block=False))
                    except gevent.queue.Empty:
                        break
                for msg in msgs:
                    self._send(msg)
            # maybe there is something to receive, if not, at least
            # ensure uWSGI does its ping/ponging
            try:
                msg = uwsgi.websocket_recv_nb()
            except IOError:  # connection closed
                self._select_greenlet.kill()
                return None
            if msg:  # message available
                return self._decode_received(msg)
예제 #48
0
class BaseServer(object):
    """An abstract base class that implements some common functionality for the servers in gevent.

    *listener* can either be an address that the server should bind on or a :class:`gevent.socket.socket`
    instance that is already bound (and put into listening mode in case of TCP socket).

    *spawn*, if provided, is called to create a new greenlet to run the handler. By default, :func:`gevent.spawn` is used.

    Possible values for *spawn*:

    * a :class:`gevent.pool.Pool` instance -- *handle* will be executed
      using :meth:`Pool.spawn` method only if the pool is not full.
      While it is full, all the connection are dropped;
    * :func:`gevent.spawn_raw` -- *handle* will be executed in a raw
      greenlet which have a little less overhead then :class:`gevent.Greenlet` instances spawned by default;
    * ``None`` -- *handle* will be executed right away, in the :class:`Hub` greenlet.
      *handle* cannot use any blocking functions as it means switching to the :class:`Hub`.
    * an integer -- a shortcut for ``gevent.pool.Pool(integer)``
    """
    # the number of seconds to sleep in case there was an error in accept() call
    # for consecutive errors the delay will double until it reaches max_delay
    # when accept() finally succeeds the delay will be reset to min_delay again
    min_delay = 0.01
    max_delay = 1

    # Sets the maximum number of consecutive accepts that a process may perform on
    # a single wake up. High values give higher priority to high connection rates,
    # while lower values give higher priority to already established connections.
    # Default is 100. Note, that in case of multiple working processes on the same
    # listening value, it should be set to a lower value. (pywsgi.WSGIServer sets it
    # to 1 when environ["wsgi.multiprocess"] is true)
    max_accept = 100

    # Default spawner; set_spawn('default') restores it by dropping any
    # instance-level override.
    _spawn = Greenlet.spawn

    # the default timeout that we wait for the client connections to close in stop()
    stop_timeout = 1

    # errno values for which a read/accept failure closes the server instead
    # of retrying with backoff (see is_fatal_error / _do_read).
    fatal_errors = (errno.EBADF, errno.EINVAL, errno.ENOTSOCK)

    def __init__(self, listener, handle=None, spawn='default'):
        self._stop_event = Event()
        # The server starts out "stopped"; start() clears the event.
        self._stop_event.set()
        self._watcher = None
        self._timer = None
        self.pool = None
        try:
            self.set_listener(listener)
            self.set_spawn(spawn)
            self.set_handle(handle)
            self.delay = self.min_delay
            self.loop = get_hub().loop
            if self.max_accept < 1:
                raise ValueError('max_accept must be positive int: %r' % (self.max_accept, ))
        except:
            # Release whatever was set up so far (e.g. the socket) before re-raising.
            self.close()
            raise

    def set_listener(self, listener):
        """Use *listener* as the socket if it has ``accept``, otherwise parse it as an address.

        Raises:
            TypeError: if *listener* is an SSL socket.
        """
        if hasattr(listener, 'accept'):
            if hasattr(listener, 'do_handshake'):
                raise TypeError('Expected a regular socket, not SSLSocket: %r' % (listener, ))
            self.family = listener.family
            self.address = listener.getsockname()
            self.socket = listener
        else:
            self.family, self.address = parse_address(listener)

    def set_spawn(self, spawn):
        """Configure how handlers are spawned (see the class docstring for accepted values)."""
        import numbers

        # A previously configured pool's ``full`` must not outlive it; it is
        # re-added below only if the new pool provides one.
        self.__dict__.pop('full', None)
        if spawn == 'default':
            self.pool = None
            # Drop any instance-level override (e.g. a previous pool's
            # ``spawn``) so that the class-level default applies again.
            self.__dict__.pop('_spawn', None)
        elif hasattr(spawn, 'spawn'):
            self.pool = spawn
            self._spawn = spawn.spawn
        elif isinstance(spawn, numbers.Integral):
            from gevent.pool import Pool
            self.pool = Pool(spawn)
            self._spawn = self.pool.spawn
        else:
            self.pool = None
            self._spawn = spawn
        if hasattr(self.pool, 'full'):
            self.full = self.pool.full
        if self.pool is not None:
            # Resume accepting whenever a pool slot is released.
            self.pool._semaphore.rawlink(self._start_accepting_if_started)

    def set_handle(self, handle):
        """Set the connection handler; falls back to a ``handle`` attribute.

        Raises:
            TypeError: if no handler was given and none is defined.
        """
        if handle is not None:
            self.handle = handle
        if hasattr(self, 'handle'):
            self._handle = self.handle
        else:
            raise TypeError("'handle' must be provided")

    def _start_accepting_if_started(self, _event=None):
        if self.started:
            self.start_accepting()

    def start_accepting(self):
        """Start watching the socket for incoming data (no-op if already watching)."""
        if self._watcher is None:
            # event mask 1 == read readiness in gevent's loop
            self._watcher = self.loop.io(self.socket.fileno(), 1)
            self._watcher.start(self._do_read)

    def stop_accepting(self):
        """Stop the socket watcher and cancel any pending restart timer."""
        if self._watcher is not None:
            self._watcher.stop()
            self._watcher = None
        if self._timer is not None:
            self._timer.stop()
            self._timer = None

    def do_handle(self, *args):
        """Run the handler inline (spawn is None) or through the configured spawner."""
        spawn = self._spawn
        if spawn is None:
            self._handle(*args)
        else:
            spawn(self._handle, *args)

    def _do_read(self):
        """Read and dispatch up to ``max_accept`` items in one wake-up.

        Fatal socket errors close the server; other errors pause accepting
        with exponential backoff (see ``_delay_accepting``).
        """
        for _ in range(self.max_accept):
            if self.full():
                # Pool saturated: stop watching until a slot is released.
                self.stop_accepting()
                return
            try:
                args = self.do_read()
                self.delay = self.min_delay
                if not args:
                    return
            except:
                # Capture the exception before handle_error runs.
                exc_info = sys.exc_info()
                ex = exc_info[1]
                self.loop.handle_error(self, *exc_info)
                if self.is_fatal_error(ex):
                    self.close()
                    sys.stderr.write('ERROR: %s failed with %s\n' % (self, str(ex) or repr(ex)))
                    return
                self._delay_accepting()
                break
            else:
                try:
                    self.do_handle(*args)
                except:
                    self.loop.handle_error((args[1:], self), *sys.exc_info())
                    self._delay_accepting()
                    break

    def _delay_accepting(self):
        """Pause accepting and restart after ``self.delay`` seconds, doubling the delay up to ``max_delay``.

        A negative ``delay`` disables this backoff (accepting continues untouched).
        """
        if self.delay >= 0:
            self.stop_accepting()
            self._timer = self.loop.timer(self.delay)
            self._timer.start(self._start_accepting_if_started)
            self.delay = min(self.max_delay, self.delay * 2)

    def full(self):
        """Return whether new work should be refused; overridden by a pool's ``full``."""
        return False

    def __repr__(self):
        return '<%s at %s %s>' % (type(self).__name__, hex(id(self)), self._formatinfo())

    def __str__(self):
        return '<%s %s>' % (type(self).__name__, self._formatinfo())

    def _formatinfo(self):
        if hasattr(self, 'socket'):
            try:
                fileno = self.socket.fileno()
            except Exception:
                ex = sys.exc_info()[1]
                fileno = str(ex)
            result = 'fileno=%s ' % fileno
        else:
            result = ''
        try:
            if isinstance(self.address, tuple) and len(self.address) == 2:
                result += 'address=%s:%s' % self.address
            else:
                result += 'address=%s' % (self.address, )
        except Exception:
            ex = sys.exc_info()[1]
            result += str(ex) or '<error>'
        try:
            handle = getfuncname(self.__dict__['handle'])
        except Exception:
            handle = None
        if handle is not None:
            result += ' handle=' + handle
        return result

    @property
    def server_host(self):
        """IP address that the server is bound to (string)."""
        if isinstance(self.address, tuple):
            return self.address[0]

    @property
    def server_port(self):
        """Port that the server is bound to (an integer)."""
        if isinstance(self.address, tuple):
            return self.address[1]

    def init_socket(self):
        """If the user initialized the server with an address rather than socket,
        then this function will create a socket, bind it and put it into listening mode.

        It is not supposed to be called by the user, it is called by :meth:`start` before starting
        the accept loop."""
        pass

    @property
    def started(self):
        return not self._stop_event.is_set()

    def start(self):
        """Start accepting the connections.

        If an address was provided in the constructor, then also create a socket,
        bind it and put it into the listening mode.
        """
        self.init_socket()
        self._stop_event.clear()
        try:
            self.start_accepting()
        except:
            self.close()
            raise

    def close(self):
        """Close the listener socket and stop accepting."""
        self._stop_event.set()
        try:
            self.stop_accepting()
        finally:
            try:
                self.socket.close()
            except Exception:
                pass
            finally:
                # Remove instance-level state so attribute lookups fall back to
                # the class defaults (and ``closed`` becomes true).
                self.__dict__.pop('socket', None)
                self.__dict__.pop('handle', None)
                self.__dict__.pop('_handle', None)
                self.__dict__.pop('_spawn', None)
                self.__dict__.pop('full', None)
                if self.pool is not None:
                    self.pool._semaphore.unlink(self._start_accepting_if_started)

    @property
    def closed(self):
        return not hasattr(self, 'socket')

    def stop(self, timeout=None):
        """Stop accepting the connections and close the listening socket.

        If the server uses a pool to spawn the requests, then :meth:`stop` also waits
        for all the handlers to exit. If there are still handlers executing after *timeout*
        has expired (default 1 second), then the currently running handlers in the pool are killed."""
        self.close()
        if timeout is None:
            timeout = self.stop_timeout
        if self.pool:
            self.pool.join(timeout=timeout)
            self.pool.kill(block=True, timeout=1)

    def serve_forever(self, stop_timeout=None):
        """Start the server if it hasn't been already started and wait until it's stopped."""
        # add test that serve_forever exists on stop()
        if not self.started:
            self.start()
        try:
            self._stop_event.wait()
        finally:
            Greenlet.spawn(self.stop, timeout=stop_timeout).join()

    def is_fatal_error(self, ex):
        """Return True if *ex* is a socket error whose errno is in ``fatal_errors``."""
        if not isinstance(ex, _socket.error):
            return False
        # ``ex.args[0]`` holds the errno on both Python 2 and 3 (``ex[0]`` is
        # Python 2 only); guard against errors raised without arguments.
        return bool(ex.args) and ex.args[0] in self.fatal_errors
예제 #49
0
class WSHandler(object):
    """Bridge one gevent-websocket connection onto the channel layer.

    The constructor does all the work and only returns once the connection is
    over. It registers the handler in the module-level ``reply_channels`` map
    (keyed by the part of the reply-channel name after ``!``) and sends
    ``websocket.connect``. It then spawns ``ws_listen`` with this handler and
    dispatches queued items until the end-of-stream sentinel arrives.
    Finally it sends ``websocket.disconnect`` and unregisters the handler.

    Items in ``self.data`` (queued by :meth:`listen` and :meth:`notify`):

    * dicts containing ``conn_key`` are outbound replies; their ``text`` is
      written to the websocket as a text frame;
    * any other non-None item is an inbound message, forwarded to the
      channel layer as ``websocket.receive``;
    * ``None`` is the end-of-stream sentinel queued by :meth:`listen`.

    Previously nothing queued the sentinel, and a ``None`` would only have
    left the inner drain loop. The disconnect message was therefore
    unreachable, and the handler waited forever once the client had gone.
    """

    def __init__(self, environ, start_response):
        self.ws = environ['wsgi.websocket']

        # Make a name for our reply channel
        self.reply_channel = default_channel_layer.new_channel(
            u"websocket.send." + default_channel_layer.client_prefix + "!")
        self.last_keepalive = time.time()

        self.data = []
        self.ev = Event()
        self.rc_name = self.reply_channel.split("!")[1]
        # Item assignment only, so no `global` statement is needed.
        reply_channels[self.rc_name] = self

        clean_headers = self._clean_headers(environ)
        self.packets_received = 0

        def make_request_info(**extra):
            # Fields shared by every message sent to the channel layer;
            # "order" is read at call time so it tracks packets_received.
            info = {
                "path": environ['PATH_INFO'],
                "headers": clean_headers,
                "query_string": environ['QUERY_STRING'],
                "client": environ['REMOTE_ADDR'],
                "server": environ['SERVER_NAME'],
                "reply_channel": self.reply_channel,
                "order": self.packets_received,
            }
            info.update(extra)
            return info

        self.request_info = make_request_info()
        default_channel_layer.send("websocket.connect", self.request_info)

        gevent.spawn(ws_listen, self)

        try:
            self._dispatch_until_disconnect(make_request_info)

            self.packets_received += 1
            self.request_info = make_request_info()
            default_channel_layer.send("websocket.disconnect", self.request_info)
        finally:
            # Without this the reply channel stayed registered forever.
            self.clean_up()

    @staticmethod
    def _clean_headers(environ):
        """Build ``[name, value]`` header pairs from a WSGI environ.

        ``HTTP_*`` keys lose their prefix and are lowercased; ``CONTENT*``
        keys are lowercased with ``_`` replaced by ``-``.
        NOTE(review): the ``HTTP_*`` branch keeps underscores
        (``user_agent`` rather than ``user-agent``). This is kept because
        consumers may depend on it; confirm the intended format.
        """
        headers = []
        for key, value in environ.items():
            if key.startswith("HTTP_"):
                headers.append([key[5:].lower(), value])
            elif key.lower().startswith("content"):
                headers.append([key.lower().replace("_", "-"), value])
        return headers

    def _dispatch_until_disconnect(self, make_request_info):
        """Drain ``self.data`` each time ``self.ev`` fires; return at the ``None`` sentinel.

        *make_request_info* builds the channel-layer message for inbound
        frames (called with ``text=<frame>``).
        """
        while True:
            self.ev.wait()
            self.ev.clear()
            while self.data:
                item = self.data.pop(0)
                if item is None:
                    return
                # Check the type first: for inbound text, `'conn_key' in item`
                # would be a substring test, and a client message containing
                # "conn_key" then crashed on `del`.
                if isinstance(item, dict) and 'conn_key' in item:
                    del item[u'conn_key']
                    if 'text' in item:
                        try:
                            self.ws.send_frame(item['text'],
                                               WebSocket.OPCODE_TEXT)
                        except Exception:
                            # As before, a failed write ends this drain pass;
                            # remaining items wait for the next wake-up. If the
                            # socket is really gone, listen() queues the sentinel.
                            break
                else:
                    self.packets_received += 1
                    self.request_info = make_request_info(text=item)
                    default_channel_layer.send("websocket.receive",
                                               self.request_info)

    def listen(self):
        """Queue every message read from the websocket until reading fails.

        Presumably run by ``ws_listen`` in its own greenlet. ``None`` results
        from ``read_message`` are skipped, as before. When the reader stops
        for any reason, a ``None`` sentinel is queued and ``self.ev`` is set so
        the dispatch loop can finish and send ``websocket.disconnect``.
        """
        try:
            while True:
                try:
                    message = self.ws.read_message()
                except Exception:
                    return
                if message is None:
                    continue
                self.data.append(message)
                self.ev.set()
        finally:
            self.data.append(None)
            self.ev.set()

    def notify(self, content):
        """Queue outbound *content* (from the channel layer) and wake the dispatcher."""
        self.data.append(content)
        self.ev.set()

    def clean_up(self):
        """Unregister this handler's reply channel; safe to call more than once."""
        reply_channels.pop(self.rc_name, None)
예제 #50
0
class Console(BaseService):
    """A service that opens an interactive IPython session.

    If ``app.start_console`` is true, the session opens right away.
    Otherwise ``self.interrupt`` is handed to ``SigINTHandler``, which
    presumably sets it when the user interrupts the process.
    NOTE(review): an earlier docstring mentioned SIGTSTP (CTRL-Z), but the
    handler's name suggests SIGINT (CTRL-C). Confirm which one applies.
    """

    name = 'console'

    def __init__(self, app):
        super(Console, self).__init__(app)
        self.interrupt = Event()
        self.console_locals = {}
        if app.start_console:
            self.start()
            self.interrupt.set()
        else:
            SigINTHandler(self.interrupt)

    def start(self):
        """Start the service and build the namespace exposed to the console."""
        # start console service
        super(Console, self).start()

        class Raiden(object):
            def __init__(self, app):
                self.app = app

        self.console_locals = dict(
            _raiden=Raiden(self.app),
            raiden=self.app.raiden,
            chain=self.app.raiden.chain,
            discovery=self.app.discovery,
            tools=ConsoleTools(
                self.app.raiden,
                self.app.discovery,
                self.app.config['settle_timeout'],
                self.app.config['reveal_timeout'],
            ),
            denoms=denoms,
            true=True,
            false=False,
            usage=print_usage,
        )

    @staticmethod
    def _split_lines(text):
        """Split captured output into lines; blank output yields no lines.

        ``text.strip().split('\\n')`` returns ``['']`` for blank output,
        which used to print a spurious empty line.
        """
        text = text.strip()
        return text.split('\n') if text else []

    @staticmethod
    def _tail(lines, n):
        """Return the last *n* items of *lines*, or ``[]`` when n <= 0.

        Plain ``lines[-n:]`` returns the whole list for n == 0.
        """
        return list(lines[-n:]) if n > 0 else []

    @staticmethod
    def _filter_log_lines(lines, prefix=None, level=None):
        """Keep lines formatted as ``LEVEL:logger message`` that match the filters.

        ``partition`` replaces ``split(':')[1]``, so lines without a colon
        (e.g. traceback continuation lines) are dropped by the prefix filter
        instead of raising IndexError. Always returns a list, so slicing
        works on Python 3 as well, unlike the result of ``filter``.
        """
        if prefix:
            lines = [line for line in lines
                     if line.partition(':')[2].startswith(prefix)]
        if level:
            lines = [line for line in lines
                     if line.partition(':')[0] == level]
        return list(lines)

    @staticmethod
    def _capture_root_logging():
        """Swap the root logger's stderr stream handlers for an in-memory buffer.

        Returns the buffer that now receives the root logger's records.
        """
        root = getLogger()
        # Remove handlers that log to stderr
        for handler in root.handlers[:]:
            if isinstance(handler,
                          StreamHandler) and handler.stream == sys.stderr:
                root.removeHandler(handler)

        stream = cStringIO.StringIO()
        handler = StreamHandler(stream=stream)
        handler.formatter = Formatter("%(levelname)s:%(name)s %(message)s")
        root.addHandler(handler)
        return stream

    def _run(self):
        """Wait for ``self.interrupt``, run IPython, then exit the process.

        Before IPython starts, stderr logging is captured in a buffer and
        ``sys.stderr`` is redirected to another buffer. The console helpers
        ``lastlog`` and ``lasterr`` print the tails of those buffers.
        """
        self.interrupt.wait()
        print('\n' * 2)
        print("Entering Console" + bc.OKGREEN)
        print("Tip:" + bc.OKBLUE)
        print_usage()

        stream = self._capture_root_logging()

        def lastlog(n=10, prefix=None, level=None):
            """Print the last `n` log lines to stdout.
            Use `prefix='p2p'` to filter for a specific logger.
            Use `level='INFO'` to filter for a specific level.
            Level- and prefix-filtering are applied before tailing the log.
            """
            lines = self._filter_log_lines(
                self._split_lines(stream.getvalue()), prefix, level)
            for line in self._tail(lines, n):
                print(line)

        self.console_locals['lastlog'] = lastlog

        err = cStringIO.StringIO()
        sys.stderr = err

        def lasterr(n=1):
            """Print the last `n` entries of stderr to stdout.
            """
            for line in self._tail(self._split_lines(err.getvalue()), n):
                print(line)

        self.console_locals['lasterr'] = lasterr

        IPython.start_ipython(argv=['--gui', 'gevent'],
                              user_ns=self.console_locals)
        self.interrupt.clear()

        sys.exit(0)
예제 #51
0
class RaidenService:
    """ A Raiden node. """

    # pylint: disable=too-many-instance-attributes,too-many-public-methods

    def __init__(self, chain, default_registry, private_key_bin, transport,
                 discovery, config):
        if not isinstance(private_key_bin,
                          bytes) or len(private_key_bin) != 32:
            raise ValueError('invalid private_key')

        invalid_timeout = (
            config['settle_timeout'] < NETTINGCHANNEL_SETTLE_TIMEOUT_MIN
            or config['settle_timeout'] > NETTINGCHANNEL_SETTLE_TIMEOUT_MAX)
        if invalid_timeout:
            raise ValueError('settle_timeout must be in range [{}, {}]'.format(
                NETTINGCHANNEL_SETTLE_TIMEOUT_MIN,
                NETTINGCHANNEL_SETTLE_TIMEOUT_MAX))

        self.token_to_channelgraph = dict()
        self.tokens_to_connectionmanagers = dict()
        self.manager_to_token = dict()
        self.swapkey_to_tokenswap = dict()
        self.swapkey_to_greenlettask = dict()

        self.identifier_to_statemanagers = defaultdict(list)
        self.identifier_to_results = defaultdict(list)

        # This is a map from a hashlock to a list of channels, the same
        # hashlock can be used in more than one token (for tokenswaps), a
        # channel should be removed from this list only when the lock is
        # released/withdrawn but not when the secret is registered.
        self.token_to_hashlock_to_channels = defaultdict(
            lambda: defaultdict(list))

        self.chain = chain
        self.default_registry = default_registry
        self.config = config
        self.privkey = private_key_bin
        self.address = privatekey_to_address(private_key_bin)

        endpoint_registration_event = gevent.spawn(
            discovery.register,
            self.address,
            config['external_ip'],
            config['external_port'],
        )
        endpoint_registration_event.link_exception(
            endpoint_registry_exception_handler)

        self.private_key = PrivateKey(private_key_bin)
        self.pubkey = self.private_key.public_key.format(compressed=False)
        self.protocol = RaidenProtocol(
            transport,
            discovery,
            self,
            config['protocol']['retry_interval'],
            config['protocol']['retries_before_backoff'],
            config['protocol']['nat_keepalive_retries'],
            config['protocol']['nat_keepalive_timeout'],
            config['protocol']['nat_invitation_timeout'],
        )

        # TODO: remove this cyclic dependency
        transport.protocol = self.protocol

        self.message_handler = RaidenMessageHandler(self)
        self.state_machine_event_handler = StateMachineEventHandler(self)
        self.blockchain_events = BlockchainEvents()
        self.greenlet_task_dispatcher = GreenletTasksDispatcher()
        self.on_message = self.message_handler.on_message
        self.alarm = AlarmTask(chain)
        self.shutdown_timeout = config['shutdown_timeout']
        self._block_number = None
        self.stop_event = Event()
        self.start_event = Event()
        self.chain.client.inject_stop_event(self.stop_event)

        self.transaction_log = StateChangeLog(
            storage_instance=StateChangeLogSQLiteBackend(
                database_path=config['database_path']))

        if config['database_path'] != ':memory:':
            self.database_dir = os.path.dirname(config['database_path'])
            self.lock_file = os.path.join(self.database_dir, '.lock')
            self.snapshot_dir = os.path.join(self.database_dir, 'snapshots')
            self.serialization_file = os.path.join(self.snapshot_dir,
                                                   'data.pickle')

            if not os.path.exists(self.snapshot_dir):
                os.makedirs(self.snapshot_dir)

            # Prevent concurrent acces to the same db
            self.db_lock = filelock.FileLock(self.lock_file)
        else:
            self.database_dir = None
            self.lock_file = None
            self.snapshot_dir = None
            self.serialization_file = None
            self.db_lock = None

        # If the endpoint registration fails the node will quit, this must
        # finish before starting the protocol
        endpoint_registration_event.join()

        self.start()

    def start(self):
        """ Start the node. """
        # XXX Should this really be here? Or will start() never be called again
        # after stop() in the lifetime of Raiden apart from the tests? This is
        # at least at the moment prompted by tests/integration/test_transer.py
        if self.stop_event and self.stop_event.is_set():
            self.stop_event.clear()

        self.alarm.start()

        # Prime the block number cache and set the callbacks
        self._block_number = self.alarm.last_block_number
        self.alarm.register_callback(self.poll_blockchain_events)
        self.alarm.register_callback(self.set_block_number)

        # Registry registration must start *after* the alarm task, this avoid
        # corner cases were the registry is queried in block A, a new block B
        # is mined, and the alarm starts polling at block C.
        self.register_registry(self.default_registry.address)

        # Restore from snapshot must come after registering the registry as we
        # need to know the registered tokens to populate `token_to_channelgraph`
        if self.database_dir is not None:
            self.db_lock.acquire(timeout=0)
            assert self.db_lock.is_locked
            self.restore_from_snapshots()

        # Start the protocol after the registry is queried to avoid warning
        # about unknown channels.
        self.protocol.start()

        # Health check needs the protocol layer
        self.start_neighbours_healthcheck()

        self.start_event.set()

    def start_neighbours_healthcheck(self):
        for graph in self.token_to_channelgraph.values():
            for neighbour in graph.get_neighbours():
                if neighbour != ConnectionManager.BOOTSTRAP_ADDR:
                    self.start_health_check_for(neighbour)

    def stop(self):
        """Stop the node and persist its state.

        The steps run in this order:

        1. Set the stop event.
        2. Stop the protocol and the alarm task.
        3. Wait at most ``shutdown_timeout`` seconds for the remaining
           greenlets.
        4. Uninstall the blockchain filters, also bounded by
           ``shutdown_timeout``.
        5. Save a snapshot.
        6. Release the database lock.

        The lock is released even if saving the snapshot raises, so the lock
        file is not left held on that error path.
        """
        # Needs to come before any greenlets joining
        self.stop_event.set()
        self.protocol.stop_and_wait()
        self.alarm.stop_async()

        wait_for = [self.alarm]
        wait_for.extend(self.protocol.greenlets)
        wait_for.extend(self.greenlet_task_dispatcher.stop())
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        gevent.wait(wait_for, timeout=self.shutdown_timeout)

        # Filters must be uninstalled after the alarm task has stopped. Since
        # the events are polled by an alarm task callback, if the filters are
        # uninstalled before the alarm task is fully stopped the callback
        # `poll_blockchain_events` will fail.
        #
        # We need a timeout to prevent an endless loop from trying to
        # contact the disconnected client
        try:
            with gevent.Timeout(self.shutdown_timeout):
                self.blockchain_events.uninstall_all_event_listeners()
        except (gevent.timeout.Timeout, RaidenShuttingDown):
            pass

        try:
            # save the state after all tasks are done
            if self.serialization_file:
                save_snapshot(self.serialization_file, self)
        finally:
            if self.db_lock is not None:
                self.db_lock.release()

    def __repr__(self):
        """Return ``<ClassName address>``, with the address formatted by ``pex``."""
        class_name = self.__class__.__name__
        return '<{} {}>'.format(class_name, pex(self.address))

    def restore_from_snapshots(self):
        """Restore channels, queues, protocol caches and transfer states from the snapshot.

        Nothing is restored unless the snapshot exists and its
        ``registry_address`` equals ``ROPSTEN_REGISTRY_ADDRESS``.
        NOTE(review): this compares against a fixed constant rather than
        ``self.default_registry.address``; confirm that is intended.

        Channels whose contract has no code are skipped with a warning.
        An AttributeError raised before any channel was restored
        successfully is logged, and the remaining channels are skipped.
        Later AttributeErrors propagate.
        """
        data = load_snapshot(self.serialization_file)
        is_usable = (
            data is not None
            and 'registry_address' in data
            and data['registry_address'] == ROPSTEN_REGISTRY_ADDRESS
        )
        if not is_usable:
            return

        nothing_restored_yet = True
        for serialized_channel in data['channels']:
            try:
                self.restore_channel(serialized_channel)
            except AddressWithoutCode as e:
                log.warn(
                    'Channel without code while restoring. Must have been '
                    'already settled while we were offline.',
                    error=str(e))
            except AttributeError as e:
                if not nothing_restored_yet:
                    raise
                log.warn(
                    'AttributeError during channel restoring. If code has changed'
                    ' then this is fine. If not then please report a bug.',
                    error=str(e))
                break
            else:
                nothing_restored_yet = False

        for serialized_queue in data['queues']:
            self.restore_queue(serialized_queue)

        protocol = self.protocol
        protocol.receivedhashes_to_acks = data['receivedhashes_to_acks']
        protocol.nodeaddresses_to_nonces = data['nodeaddresses_to_nonces']

        self.restore_transfer_states(data['transfers'])

    def set_block_number(self, block_number):
        """Dispatch a ``Block`` state change to every state task and every channel.

        The cached block number is updated only at the end.
        """
        block_change = Block(block_number)
        self.state_machine_event_handler.log_and_dispatch_to_all_tasks(
            block_change)

        every_channel = (
            channel
            for graph in self.token_to_channelgraph.values()
            for channel in graph.address_to_channel.values()
        )
        for channel in every_channel:
            channel.state_transition(block_change)

        # To avoid races, only update the internal cache after all the state
        # tasks have been updated.
        self._block_number = block_number

    def set_node_network_state(self, node_address, network_state):
        for graph in self.token_to_channelgraph.values():
            channel = graph.partneraddress_to_channel.get(node_address)

            if channel:
                channel.network_state = network_state

    def start_health_check_for(self, node_address):
        self.protocol.start_health_check(node_address)

    def get_block_number(self):
        return self._block_number

    def poll_blockchain_events(self, current_block=None):
        # pylint: disable=unused-argument
        on_statechange = self.state_machine_event_handler.on_blockchain_statechange

        for state_change in self.blockchain_events.poll_state_change(
                self._block_number):
            on_statechange(state_change)

    def find_channel_by_address(self, netting_channel_address_bin):
        for graph in self.token_to_channelgraph.values():
            channel = graph.address_to_channel.get(netting_channel_address_bin)

            if channel is not None:
                return channel

        raise ValueError('unknown channel {}'.format(
            encode_hex(netting_channel_address_bin)))

    def sign(self, message):
        """Sign *message* in place with this node's private key and address.

        Raises:
            ValueError: if *message* is not a ``SignedMessage``.
        """
        if isinstance(message, SignedMessage):
            message.sign(self.private_key, self.address)
        else:
            raise ValueError('{} is not signable.'.format(repr(message)))

    def send_async(self, recipient, message):
        """Send *message* to *recipient* through the raiden protocol.

        The protocol resends the message on an interval until an
        Acknowledgment arrives or its retries run out. Returns whatever
        ``self.protocol.send_async`` returns.

        Raises:
            ValueError: if *recipient* is not a valid address or is this node.
        """
        problem = None
        if not isaddress(recipient):
            problem = 'recipient is not a valid address.'
        elif recipient == self.address:
            problem = 'programming error, sending message to itself'

        if problem is not None:
            raise ValueError(problem)

        return self.protocol.send_async(recipient, message)

    def send_and_wait(self, recipient, message, timeout):
        """ Send `message` to `recipient` and wait for the response or `timeout`.

        Args:
            recipient (address): The address of the node that will receive the
                message.
            message: The transfer message.
            timeout (float): How long should we wait for a response from `recipient`.

        Returns:
            The value returned by ``self.protocol.send_and_wait``. Per this
            method's documented contract, that is None if the wait timed
            out, otherwise the result from the event. The value used to be
            discarded, so this method always returned None.

        Raises:
            ValueError: if `recipient` is not a valid address.
        """
        if not isaddress(recipient):
            raise ValueError('recipient is not a valid address.')

        return self.protocol.send_and_wait(recipient, message, timeout)

    def register_secret(self, secret: bytes):
        """ Register the secret with any channel that has a hashlock on it.

        This must search through all channels registered for a given hashlock
        and ignoring the tokens. Useful for refund transfer, split transfer,
        and token swaps. Each such channel registers the secret, and its
        partner is sent a signed RevealSecret message.

        Raises:
            TypeError: If secret is unicode data.
        """
        if not isinstance(secret, bytes):
            raise TypeError('secret must be bytes')

        hashlock = sha3(secret)
        revealsecret_message = RevealSecret(secret)
        self.sign(revealsecret_message)

        for hashlock_to_channels in self.token_to_hashlock_to_channels.values():
            # `.get` rather than `[]`: indexing the inner defaultdict inserted
            # an empty list under this hashlock for every token on each call,
            # so empty entries could accumulate.
            for channel in hashlock_to_channels.get(hashlock, ()):
                channel.register_secret(secret)

                # The protocol ignores duplicated messages.
                self.send_async(
                    channel.partner_state.address,
                    revealsecret_message,
                )

    def register_channel_for_hashlock(self, token_address, channel, hashlock):
        channels_registered = self.token_to_hashlock_to_channels[
            token_address][hashlock]

        if channel not in channels_registered:
            channels_registered.append(channel)

    def handle_secret(  # pylint: disable=too-many-arguments
            self, identifier, token_address, secret, partner_secret_message,
            hashlock):
        """ Unlock/Withdraw locks, register the secret, and send Secret
        messages as necessary.

        For every channel registered for *hashlock* under *token_address*:
            - If our side holds the lock, a signed Secret message is created,
              registered as our transfer and sent to the partner so that she
              can withdraw the token. The channel is then unregistered.
            - If the partner's side holds the lock and *partner_secret_message*
              is a balance proof from that partner for this channel, it is
              registered as a transfer and the channel is unregistered.
            - Otherwise the raw secret is registered and revealed to the
              partner with a RevealSecret message.
        Channels registered for the hashlock that know neither lock are
        logged as errors. All messages are sent after the channels have been
        updated.

        Args:
            identifier: transfer identifier passed to ``create_secret``.
            token_address: token whose hashlock registry is consulted.
            secret (bytes): the revealed secret.
            partner_secret_message: the partner's Secret message, or a falsy
                value if there is none.
            hashlock: the lock's hash (presumably ``sha3(secret)``).

        Note:
            The channel needs to be registered with
            `raiden.register_channel_for_hashlock`.
        """
        channels_list = self.token_to_hashlock_to_channels[token_address][
            hashlock]
        channels_to_remove = list()

        revealsecret_message = RevealSecret(secret)
        self.sign(revealsecret_message)

        messages_to_send = []
        for channel in channels_list:
            # unlock a pending lock
            if channel.our_state.is_known(hashlock):
                # Keep the message in its own name. Rebinding `secret` here
                # passed a message object instead of the raw bytes to every
                # channel handled later in this loop.
                secret_message = channel.create_secret(identifier, secret)
                self.sign(secret_message)

                channel.register_transfer(
                    self.get_block_number(),
                    secret_message,
                )

                messages_to_send.append((
                    channel.partner_state.address,
                    secret_message,
                ))

                channels_to_remove.append(channel)

            # withdraw a pending lock
            elif channel.partner_state.is_known(hashlock):
                is_balance_proof = bool(partner_secret_message) and (
                    partner_secret_message.sender
                    == channel.partner_state.address
                    and partner_secret_message.channel
                    == channel.channel_address)

                if is_balance_proof:
                    channel.register_transfer(
                        self.get_block_number(),
                        partner_secret_message,
                    )
                    channels_to_remove.append(channel)
                else:
                    channel.register_secret(secret)
                    messages_to_send.append((
                        channel.partner_state.address,
                        revealsecret_message,
                    ))

            else:
                log.error(
                    'Channel is registered for a given lock but the lock is not contained in it.'
                )

        for channel in channels_to_remove:
            channels_list.remove(channel)

        if not channels_list:
            del self.token_to_hashlock_to_channels[token_address][hashlock]

        # send the messages last to avoid races
        for recipient, message in messages_to_send:
            self.send_async(
                recipient,
                message,
            )

    def get_channel_details(self, token_address, netting_channel):
        """Build ``ChannelDetails`` for *netting_channel* from its on-chain ``detail()``.

        Both end states start without a balance proof and with an empty
        merkle tree. The reveal timeout comes from the node config, the
        settle timeout from the channel detail.
        """
        details = netting_channel.detail()

        our_state, partner_state = (
            ChannelEndState(
                details[address_key],
                details[balance_key],
                None,
                EMPTY_MERKLE_TREE,
            )
            for address_key, balance_key in (
                ('our_address', 'our_balance'),
                ('partner_address', 'partner_balance'),
            )
        )

        def register_for_hashlock(channel, hashlock):
            self.register_channel_for_hashlock(token_address, channel, hashlock)

        external_state = ChannelExternalState(
            register_for_hashlock,
            netting_channel,
        )

        return ChannelDetails(
            netting_channel.address,
            our_state,
            partner_state,
            external_state,
            self.config['reveal_timeout'],
            details['settle_timeout'],
        )

    def restore_channel(self, serialized_channel):
        """Re-add a serialized channel to its token graph.

        Balances are re-read from the blockchain. Balance proofs, merkle
        leaves and the reveal timeout come from *serialized_channel*.
        NOTE(review): if the graph does not end up containing the channel,
        the final attribute assignments raise AttributeError, which
        ``restore_from_snapshots`` handles specially.
        """
        token_address = serialized_channel.token_address
        netting_channel = self.chain.netting_channel(
            serialized_channel.channel_address, )

        # restoring balances from the blockchain since the serialized
        # value could be falling behind.
        onchain = netting_channel.detail()

        # our_address is checked by detail
        assert onchain[
            'partner_address'] == serialized_channel.partner_address

        def tree_from(leaves):
            # An empty leaf list maps to the shared empty tree.
            if not leaves:
                return EMPTY_MERKLE_TREE
            return MerkleTreeState(compute_layers(leaves))

        our_tree = tree_from(serialized_channel.our_leaves)
        our_state = ChannelEndState(
            onchain['our_address'],
            onchain['our_balance'],
            serialized_channel.our_balance_proof,
            our_tree,
        )

        partner_tree = tree_from(serialized_channel.partner_leaves)
        partner_state = ChannelEndState(
            onchain['partner_address'],
            onchain['partner_balance'],
            serialized_channel.partner_balance_proof,
            partner_tree,
        )

        def register_for_hashlock(channel, hashlock):
            self.register_channel_for_hashlock(token_address, channel, hashlock)

        details = ChannelDetails(
            serialized_channel.channel_address,
            our_state,
            partner_state,
            ChannelExternalState(register_for_hashlock, netting_channel),
            serialized_channel.reveal_timeout,
            onchain['settle_timeout'],
        )

        graph = self.token_to_channelgraph[token_address]
        graph.add_channel(details)
        restored = graph.address_to_channel.get(
            serialized_channel.channel_address, )

        restored.our_state.balance_proof = serialized_channel.our_balance_proof
        restored.partner_state.balance_proof = serialized_channel.partner_balance_proof

    def restore_queue(self, serialized_queue):
        receiver_address = serialized_queue['receiver_address']
        token_address = serialized_queue['token_address']

        queue = self.protocol.get_channel_queue(
            receiver_address,
            token_address,
        )

        for messagedata in serialized_queue['messages']:
            queue.put(messagedata)

    def restore_transfer_states(self, transfer_states):
        self.identifier_to_statemanagers = transfer_states

    def register_registry(self, registry_address):
        """Install event listeners for a registry and build a graph per channel manager.

        For every channel manager relevant to this node, a ``ChannelGraph``
        and a ``ConnectionManager`` are registered under the manager's token.
        """
        proxies = get_relevant_proxies(
            self.chain,
            self.address,
            registry_address,
        )

        # Install the filters first to avoid missing changes, as a consequence
        # some events might be applied twice.
        self.blockchain_events.add_proxies_listeners(proxies)

        for manager in proxies.channel_managers:
            token_address = manager.token_address()
            manager_address = manager.address

            channels_detail = [
                self.get_channel_details(token_address, netting_channel)
                for netting_channel in proxies.channelmanager_nettingchannels[
                    manager_address]
            ]

            graph = ChannelGraph(
                self.address,
                manager_address,
                token_address,
                manager.channels_addresses(),
                channels_detail,
            )

            self.manager_to_token[manager_address] = token_address
            self.token_to_channelgraph[token_address] = graph
            self.tokens_to_connectionmanagers[token_address] = ConnectionManager(
                self, token_address, graph)

    def channel_manager_is_registered(self, manager_address):
        return manager_address in self.manager_to_token

    def register_channel_manager(self, manager_address):
        """Start tracking a channel manager of the default registry.

        Listeners are installed for the manager and for this node's netting
        channels. Then a ``ChannelGraph`` and a ``ConnectionManager`` are
        registered under the manager's token.
        """
        manager = self.default_registry.manager(manager_address)
        our_channel_addresses = manager.channels_by_participant(self.address)
        netting_channels = list(map(self.chain.netting_channel,
                                    our_channel_addresses))

        # Install the filters first to avoid missing changes, as a consequence
        # some events might be applied twice.
        events = self.blockchain_events
        events.add_channel_manager_listener(manager)
        for netting_channel in netting_channels:
            events.add_netting_channel_listener(netting_channel)

        token_address = manager.token_address()
        edge_list = manager.channels_addresses()
        channels_detail = [
            self.get_channel_details(token_address, netting_channel)
            for netting_channel in netting_channels
        ]

        graph = ChannelGraph(
            self.address,
            manager_address,
            token_address,
            edge_list,
            channels_detail,
        )

        self.manager_to_token[manager_address] = token_address
        self.token_to_channelgraph[token_address] = graph
        self.tokens_to_connectionmanagers[token_address] = ConnectionManager(
            self, token_address, graph)

    def register_netting_channel(self, token_address, channel_address):
        netting_channel = self.chain.netting_channel(channel_address)
        self.blockchain_events.add_netting_channel_listener(netting_channel)

        detail = self.get_channel_details(token_address, netting_channel)
        graph = self.token_to_channelgraph[token_address]
        graph.add_channel(detail)

    def connection_manager_for_token(self, token_address):
        """Return the ConnectionManager registered for *token_address*.

        Raises:
            InvalidAddress: if the address is malformed or the token is not
                registered.
        """
        if not isaddress(token_address):
            raise InvalidAddress('token address is not valid.')

        managers = self.tokens_to_connectionmanagers
        if token_address not in managers:
            raise InvalidAddress('token is not registered.')
        return managers[token_address]

    def leave_all_token_networks_async(self):
        """Start leaving every token network without blocking.

        Returns an ``AsyncResult`` linked to a greenlet that waits on all the
        individual leave results. Tokens for which an ``InvalidAddress`` is
        raised are skipped.
        """
        leave_results = []
        for token_address in self.token_to_channelgraph:
            try:
                manager = self.connection_manager_for_token(token_address)
                leave_results.append(manager.leave_async())
            except InvalidAddress:
                continue

        combined_result = AsyncResult()
        waiter = gevent.spawn(gevent.wait, leave_results)
        waiter.link(combined_result)
        return combined_result

    def close_and_settle(self):
        """ Leave all token networks and wait for the channels to settle.

        Snapshots the open channels of every connection manager, starts
        leaving all token networks asynchronously, then polls the chain
        (sleeping `self.alarm.wait_time` between polls). Polling stops once
        the block number reaches the start block plus the largest
        `min_settle_blocks` of the connection managers, or as soon as every
        snapshotted channel is settled. Channels that are still not settled
        at the end are logged as an error.

        Returns immediately when no token network is registered.
        """
        log.info('raiden will close and settle all channels now')

        connection_managers = [
            self.connection_manager_for_token(token_address)
            for token_address in self.token_to_channelgraph
        ]

        if not connection_managers:
            # `max()` in blocks_to_wait() raises ValueError on an empty
            # sequence; with no token networks there is nothing to settle.
            log.debug('no token networks registered, nothing to settle')
            return

        def blocks_to_wait():
            return max(connection_manager.min_settle_blocks
                       for connection_manager in connection_managers)

        # Snapshot before leaving, leaving may change the open channel lists.
        all_channels = list(
            itertools.chain.from_iterable(
                connection_manager.open_channels
                for connection_manager in connection_managers
            ))

        def unsettled_channels():
            return [
                channel for channel in all_channels
                if channel.state != CHANNEL_STATE_SETTLED
            ]

        leaving_greenlet = self.leave_all_token_networks_async()
        # using the un-cached block number here
        last_block = self.chain.block_number()

        earliest_settlement = last_block + blocks_to_wait()

        # TODO: estimate and set a `timeout` parameter in seconds
        # based on connection_manager.min_settle_blocks and an average
        # blocktime from the past

        current_block = last_block
        while current_block < earliest_settlement:
            gevent.sleep(self.alarm.wait_time)
            last_block = self.chain.block_number()
            if last_block != current_block:
                current_block = last_block
                avg_block_time = self.chain.estimate_blocktime()
                wait_blocks_left = blocks_to_wait()
                not_settled = len(unsettled_channels())
                if not_settled == 0:
                    log.debug('nothing left to settle')
                    break
                # Note the trailing space: the two literals are concatenated.
                log.info(
                    'waiting at least %s more blocks (~%s sec) for settlement '
                    '(%s channels not yet settled)' %
                    (wait_blocks_left, wait_blocks_left * avg_block_time,
                     not_settled))

            leaving_greenlet.wait(timeout=blocks_to_wait() *
                                  self.chain.estimate_blocktime() * 1.5)

        still_unsettled = unsettled_channels()
        if still_unsettled:
            log.error('Some channels were not settled!',
                      channels=[
                          pex(channel.channel_address)
                          for channel in still_unsettled
                      ])

    def mediated_transfer_async(self, token_address, amount, target,
                                identifier):
        """ Start a mediated transfer of `amount` tokens from this node to `target`.

        The transfer proceeds asynchronously and may fail or succeed depending
        on, among other things:

            - whether a usable path exists, via a direct channel or through
              intermediary channels;
            - network speed, i.e. whether the transfer completes before it
              expires.

        Returns:
            The AsyncResult produced by `start_mediated_transfer`.
        """
        return self.start_mediated_transfer(
            token_address,
            amount,
            identifier,
            target,
        )

    def direct_transfer_async(self, token_address, amount, target, identifier):
        """ Pay `target` directly over an existing channel, without mediation.

        A direct transfer is a signed balance proof with an incremented
        transferred amount. It can be neither cancelled nor expire. Once the
        message is sent, the target is effectively paid and the payment
        cannot be reverted, so this implies trusting the target.

        The returned AsyncResult is set to False only when there is no direct
        channel with `target`, or that channel cannot transfer or has less
        than `amount` distributable. Otherwise the result of
        `protocol.send_async` is returned. Because the transfer never expires,
        that result will not be set to False and stays pending until the
        target acknowledges the message.

        Intended as an optimization: the payer needs only two packets to
        complete it, versus six messages for a mediated transfer.
        """
        token_graph = self.token_to_channelgraph[token_address]
        channel = token_graph.partneraddress_to_channel.get(target)

        has_capacity = (
            channel and
            channel.can_transfer and
            amount <= channel.distributable
        )
        if not has_capacity:
            failed = AsyncResult()
            failed.set(False)
            return failed

        transfer_message = channel.create_directtransfer(amount, identifier)
        self.sign(transfer_message)
        channel.register_transfer(self.get_block_number(), transfer_message)

        partner_address = channel.partner_state.address

        state_change = ActionTransferDirect(
            identifier,
            amount,
            token_address,
            partner_address,
        )
        # TODO: add the transfer sent event
        state_change_id = self.transaction_log.log(state_change)

        # TODO: This should be set once the direct transfer is acknowledged
        sent_event = EventTransferSentSuccess(identifier, amount, target)
        self.transaction_log.log_events(
            state_change_id,
            [sent_event],
            self.get_block_number(),
        )

        return self.protocol.send_async(partner_address, transfer_message)

    def start_mediated_transfer(self, token_address, amount, identifier,
                                target):
        """ Create and dispatch an initiator state machine for a new transfer.

        Returns an AsyncResult. It is set to False immediately when
        `get_best_routes` finds no route to `target`. Otherwise it is stored,
        together with the new state manager, under the transfer identifier.
        A default identifier is created when `identifier` is None.
        """
        # pylint: disable=too-many-locals
        result = AsyncResult()
        token_graph = self.token_to_channelgraph[token_address]

        routes = get_best_routes(
            token_graph,
            self.protocol.nodeaddresses_networkstatuses,
            self.address,
            target,
            amount,
            None,
        )
        if not routes:
            result.set(False)
            return result

        self.protocol.start_health_check(target)

        if identifier is None:
            identifier = create_default_identifier()

        routes_state = RoutesState(routes)
        node_address = self.address
        current_block = self.get_block_number()

        locked_transfer = LockedTransferState(
            identifier=identifier,
            amount=amount,
            token=token_address,
            initiator=self.address,
            target=target,
            expiration=None,
            hashlock=None,
            secret=None,
        )

        # Issue #489
        #
        # Raiden may crash after a state change that uses the random generator
        # has been handled but before the snapshot is taken. On the next start,
        # replaying the pending state changes then generates a new secret and
        # the resulting events differ. That breaks the architecture's
        # assumption that re-executing a state change always produces the same
        # events.
        #
        # TODO: Remove the secret generator from the InitiatorState and add
        # the secret to every state change that requires one. The secret is
        # then serialized with the state change and recovery reuses the same
        # /random/ secret.
        secret_generator = RandomSecretGenerator()

        init_action = ActionInitInitiator(
            our_address=node_address,
            transfer=locked_transfer,
            routes=routes_state,
            random_generator=secret_generator,
            block_number=current_block,
        )

        manager = StateManager(initiator.state_transition, None)
        self.state_machine_event_handler.log_and_dispatch(manager, init_action)

        # TODO: implement the network timeout raiden.config['msg_timeout'] and
        # cancel the current transfer if it happens (issue #374)
        self.identifier_to_statemanagers[identifier].append(manager)
        self.identifier_to_results[identifier].append(result)

        return result

    def mediate_mediated_transfer(self, message):
        """ Start a mediator state machine for an incoming mediated transfer.

        Outgoing routes toward `message.target` are computed with
        `get_best_routes`, passing `message.sender` as the final argument. The
        incoming route is built from the channel with the sender. The new state
        manager is registered under `message.identifier`.
        """
        # pylint: disable=too-many-locals
        identifier = message.identifier
        amount = message.lock.amount
        token_graph = self.token_to_channelgraph[message.token]

        outgoing_routes = get_best_routes(
            token_graph,
            self.protocol.nodeaddresses_networkstatuses,
            self.address,
            message.target,
            amount,
            message.sender,
        )

        incoming_channel = token_graph.partneraddress_to_channel[message.sender]
        incoming_route = channel_to_routestate(incoming_channel, message.sender)

        init_action = ActionInitMediator(
            self.address,
            lockedtransfer_from_message(message),
            RoutesState(outgoing_routes),
            incoming_route,
            self.get_block_number(),
        )

        manager = StateManager(mediator.state_transition, None)
        self.state_machine_event_handler.log_and_dispatch(manager, init_action)

        self.identifier_to_statemanagers[identifier].append(manager)

    def target_mediated_transfer(self, message):
        """ Start a target state machine for a mediated transfer sent to this node.

        The incoming route is built from the channel with `message.sender`. The
        new state manager is registered under `message.identifier`.
        """
        token_graph = self.token_to_channelgraph[message.token]
        incoming_channel = token_graph.partneraddress_to_channel[message.sender]

        init_action = ActionInitTarget(
            self.address,
            channel_to_routestate(incoming_channel, message.sender),
            lockedtransfer_from_message(message),
            self.get_block_number(),
        )

        manager = StateManager(target_task.state_transition, None)
        self.state_machine_event_handler.log_and_dispatch(manager, init_action)

        self.identifier_to_statemanagers[message.identifier].append(manager)
예제 #52
0
class QCProcessor(SimpleProcess):
    '''
    Process that periodically evaluates QC parameters of data products.

    A worker spawned in on_start() calls qc_processing_loop() and
    event_processing_loop() about once per second until suspend() sets
    self.event. Events delivered to receive_event() are queued for the worker.
    '''

    # Splits an OOI short name such as "TEMPWAT_L1" into the data product
    # identifier ("TEMPWAT") and its level suffix ("_L1").
    _SHORT_NAME_LEVEL_RE = re.compile(r'([a-zA-Z-_]+)(_L[0-9])')

    def __init__(self):
        self.event = Event()  # Set to ask the worker loop to stop
        self.timeout = 10  # Seconds suspend() waits for the worker to exit

    def on_start(self):
        '''
        Process initialization

        Everything the worker and the event callback use (timeout, resource
        registry, event queue) is created before the subscriber and the worker
        are started. An event delivered right after subscribing, or the first
        worker iteration, can then never hit a missing attribute.
        '''
        self.timeout = self.CFG.get_safe('endpoint.receive.timeout', 10)
        self.resource_registry = self.container.resource_registry
        self.event_queue = Queue()

        self._event_subscriber = EventSubscriber(
            event_type=OT.ResetQCEvent,
            callback=self.receive_event,
            auto_delete=True)  # TODO Correct event types
        self._event_subscriber.start()

        self._thread = self._process.thread_manager.spawn(self.thread_loop)

    def on_quit(self):
        '''
        Stop and cleanup the thread
        '''
        self._event_subscriber.stop()
        self.suspend()

    def receive_event(self, event, *args, **kwargs):
        '''
        Subscriber callback: queue the event for the worker.
        '''
        log.debug("Adding event to the event queue")
        self.event_queue.put(event)

    def thread_loop(self):
        '''
        Asynchronous event-loop

        About once per second, runs a QC pass and drains the event queue until
        self.event is set. Errors in either step are logged and the loop keeps
        going. `except Exception` (not a bare except) lets BaseException
        subclasses such as GreenletExit/SystemExit stop the worker instead of
        being swallowed.
        '''
        threading.current_thread().name = '%s-qc-processor' % self.id
        while not self.event.wait(1):
            try:
                self.qc_processing_loop()
            except Exception:
                log.error("Error in QC Processing Loop", exc_info=True)
            try:
                self.event_processing_loop()
            except Exception:
                log.error("Error in QC Event Loop", exc_info=True)

    def qc_processing_loop(self):
        '''
        Iterates through available data products and evaluates QC

        A failure for one data product is logged and does not prevent the
        remaining data products from being processed.
        '''
        data_products, _ = self.container.resource_registry.find_resources(
            restype=RT.DataProduct, id_only=False)
        for data_product in data_products:
            try:
                self._process_data_product(data_product)
            except Exception:
                log.error("QC failed for data product %s", data_product._id,
                          exc_info=True)

            # Break early if we can
            if self.event.is_set():
                break

    def _process_data_product(self, data_product):
        '''
        Runs QC for every '*_qc' parameter of one data product; data products
        without a reference designator are skipped.
        '''
        try:
            rd = self.get_reference_designator(data_product._id)
        except BadRequest:
            return
        parameters = self.get_parameters(data_product)
        qc_mapping = self._build_qc_mapping(parameters)

        for p in parameters:
            # for each parameter, if the name ends in _qc run the qc
            if p.name.endswith('_qc'):
                self.run_qc(data_product, rd, p, qc_mapping, parameters)

    def _build_qc_mapping(self, parameters):
        '''
        Creates a dictionary { data_product_name : parameter_name } from the
        parameters that have an OOI short name, stripping any "_L<digit>"
        level suffix from the short name.
        '''
        qc_mapping = {}
        for p in parameters:
            if not p.ooi_short_name:
                continue
            sname = p.ooi_short_name
            match = self._SHORT_NAME_LEVEL_RE.match(sname)
            if match:
                sname = match.groups()[0]
            qc_mapping[sname] = p.name
        return qc_mapping

    def event_processing_loop(self):
        '''
        Processes the events in the event queue

        A StopIteration sentinel is queued first so that iterating the queue
        ends once the events queued so far have been consumed.
        '''
        log.debug("Processing event queue")
        self.event_queue.put(StopIteration)
        for event in self.event_queue:
            log.debug("My event's reference designator: %s", event.origin)

    def suspend(self):
        '''
        Stops the event loop and waits up to self.timeout seconds for it
        '''
        self.event.set()
        self._thread.join(self.timeout)
        log.info("QC Thread Suspended")

    def get_reference_designator(self, data_product_id=''):
        '''
        Returns the reference designator for a data product if it has one

        Follows hasDataProductParent to the parent data product first. Then
        the reference designator of the InstrumentSite that holds the
        InstrumentDevice producing the data product is returned.

        Raises BadRequest if no device or no site is associated.
        '''
        # First try to get the parent data product
        parent_ids, _ = self.resource_registry.find_objects(
            subject=data_product_id,
            predicate=PRED.hasDataProductParent,
            id_only=True)
        if parent_ids:
            return self.get_reference_designator(parent_ids[0])

        device_ids, _ = self.resource_registry.find_subjects(
            object=data_product_id,
            predicate=PRED.hasOutputProduct,
            subject_type=RT.InstrumentDevice,
            id_only=True)
        if not device_ids:
            raise BadRequest(
                "No instrument device associated with this data product")

        sites, _ = self.resource_registry.find_subjects(
            object=device_ids[0],
            predicate=PRED.hasDevice,
            subject_type=RT.InstrumentSite,
            id_only=False)
        if not sites:
            raise BadRequest("No site is associated with this data product")
        return sites[0].reference_designator

    def calibrated_candidates(self, data_product, parameter, qc_mapping,
                              parameters):
        '''
        Returns the parameter name to use as the input parameter

        Priority:
          1. <input short name lowercased>b_interp
          2. <input short name lowercased>b_pd
          3. input_name, i.e. qc_mapping[<data product id of the QC parameter>]
        The first candidate present in `parameters` is returned; the third is
        returned unconditionally as the fallback.
        '''
        parameters = {p.name: p for p in parameters}

        dp_ident, _, _ = parameter.ooi_short_name.split('_')
        input_name = qc_mapping[dp_ident]  # input_name is the third priority

        # should be something like tempwat_l1
        sname = parameters[input_name].ooi_short_name

        interp = sname.lower() + 'b_interp'
        pd = sname.lower() + 'b_pd'

        log.debug("QC input candidates: 1st %s, 2nd %s, 3rd %s",
                  interp, pd, input_name)

        if interp in parameters:
            return interp
        elif pd in parameters:
            return pd
        else:
            return input_name

    def run_qc(self, data_product, reference_designator, parameter, qc_mapping,
               parameters):
        '''
        Determines which algorithm the parameter should run, then evaluates the QC

        data_product         - Data Product Resource
        reference_designator - reference designator string
        parameter            - parameter context resource
        qc_mapping           - a dictionary of { data_product_name : parameter_name }
        parameters           - all parameter contexts of the data product

        Returns without loading a coverage when the QC short name is not of
        the form DATAPRD_ALGRTHM_QC, no input parameter is mapped, or there is
        no lookup document/entry for the reference designator. A missing
        lookup table, or one without a usable row, is reported via set_error().
        '''
        # We key off of the OOI Short Name
        # DATAPRD_ALGRTHM_QC
        name_parts = (parameter.ooi_short_name or '').split('_')
        if len(name_parts) != 3:
            # Unpacking would raise and abort the QC pass for every remaining
            # data product, so skip the malformed parameter instead.
            log.debug("Skipping QC parameter %s: unexpected short name %r",
                      parameter.name, parameter.ooi_short_name)
            return
        dp_ident, alg, _ = name_parts
        if dp_ident not in qc_mapping:
            return  # No input!
        input_name = self.calibrated_candidates(data_product, parameter,
                                                qc_mapping, parameters)

        try:
            doc = self.container.object_store.read_doc(reference_designator)
        except NotFound:
            return  # NO QC lookups found
        if dp_ident not in doc:
            log.critical("Data product %s not in doc", dp_ident)
            return  # No data product of this listing in the RD's entry
        # Lookup table has the rows for the QC inputs
        lookup_table = doc[dp_ident]

        # An instance of the coverage is loaded if we need to run an algorithm
        dataset_id = self.get_dataset(data_product)
        coverage = self.get_coverage(dataset_id)
        try:
            if not coverage.num_timesteps:  # No data = no qc
                return
            self._run_algorithm(alg.lower(), lookup_table, coverage, parameter,
                                input_name)
        except KeyError:  # No lookup table
            self.set_error(coverage, parameter)
        finally:
            # Always release the coverage, including on the early return.
            coverage.close()

    def _run_algorithm(self, algorithm, lookup_table, coverage, parameter,
                       input_name):
        '''
        Reads the lookup-table arguments for `algorithm` (a lowercased
        ALGRTHM token) and runs the matching process_* method; unknown
        algorithms are ignored. Missing tables/columns raise KeyError.
        '''
        if algorithm == 'glblrng':
            row = self._lookup_row(lookup_table, 'global_range')
            self.process_glblrng(coverage, parameter, input_name,
                                 row['min_value'], row['max_value'])

        elif algorithm == 'stuckvl':
            row = self._lookup_row(lookup_table, 'stuck_value')
            self.process_stuck_value(coverage, parameter, input_name,
                                     row['resolution'],
                                     row['consecutive_values'])

        elif algorithm == 'trndtst':
            row = self._lookup_row(lookup_table, 'trend_test')
            self.process_trend_test(coverage, parameter, input_name,
                                    row['polynomial_order'],
                                    row['standard_deviation'])

        elif algorithm == 'spketst':
            row = self._lookup_row(lookup_table, 'spike_test')
            self.process_spike_test(coverage, parameter, input_name,
                                    row['accuracy'], row['range_multiplier'],
                                    row['window_length'])

        elif algorithm == 'gradtst':
            row = self._lookup_row(lookup_table, 'gradient_test')
            ddatdx = row['ddatdx']
            mindx = row['mindx']
            startdat = row['startdat']
            # Empty strings in the lookup table mean "not set".
            if isinstance(startdat, basestring) and not startdat:
                startdat = np.nan
            if isinstance(mindx, basestring) and not mindx:
                mindx = np.nan
            toldat = row['toldat']
            self.process_gradient_test(coverage, parameter, input_name,
                                       ddatdx, mindx, startdat, toldat)

        elif algorithm == 'loclrng':
            row = self._lookup_row(lookup_table, 'local_range')
            table = row['table']
            dims = []
            datlimz = []
            for key in table.iterkeys():
                # Skip the datlims
                if 'datlim' in key:
                    continue
                dims.append(key)
                datlimz.append(table[key])

            datlimz = np.column_stack(datlimz)
            datlim = np.column_stack([table['datlim1'], table['datlim2']])
            self.process_local_range_test(coverage, parameter, input_name,
                                          datlim, datlimz, dims)

    def _lookup_row(self, lookup_table, table_name):
        '''
        Returns the most recent row of lookup_table[table_name].

        Raises KeyError when the table is missing or recent_row() finds no
        row, so run_qc() reports both cases the same way instead of failing
        with a TypeError on `None[...]`.
        '''
        row = self.recent_row(lookup_table[table_name])
        if row is None:
            raise KeyError(table_name)
        return row

    def set_error(self, coverage, parameter):
        '''
        Only logs the failure; the coverage is not modified.
        '''
        log.error("setting coverage parameter %s to -99", parameter.name)

    def get_parameter_values(self, coverage, name):
        '''
        Returns the full value array of parameter `name` from the coverage.
        '''
        array = coverage.get_parameter_values(
            [name], fill_empty_params=True).get_data()[name]
        return array

    def _unevaluated_indexes(self, coverage, parameter):
        '''
        Returns the positions where the QC parameter is -88 (not yet evaluated).
        '''
        qc_array = self.get_parameter_values(coverage, parameter.name)
        return np.where(qc_array == -88)[0]

    def process_glblrng(self, coverage, parameter, input_name, min_value,
                        max_value):
        '''
        Evaluates the QC for global range for all data values that equal -88 (not yet evaluated)
        '''
        log.debug("input name: %s", input_name)
        log.info("Num timesteps: %s", coverage.num_timesteps)

        indexes = self._unevaluated_indexes(coverage, parameter)

        # The range test is pointwise, so only the unevaluated samples are
        # fetched; the time values are kept to locate where the results go.
        time_array = self.get_parameter_values(
            coverage, coverage.temporal_parameter_name)[indexes]
        value_array = self.get_parameter_values(coverage, input_name)[indexes]

        from ion_functions.qc.qc_functions import dataqc_globalrangetest
        qc = dataqc_globalrangetest(value_array, [min_value, max_value])
        return_dictionary = {
            coverage.temporal_parameter_name: time_array,
            parameter.name: qc
        }

    # The tests below work on neighbouring samples (windows/trends), so they
    # are evaluated over the full series and only the results at the
    # unevaluated positions are kept. The previous code indexed arrays that
    # were already subset by `indexes`, which raised IndexError whenever only
    # part of the series was unevaluated. Assumes each *_wrapper returns one
    # flag per input sample — TODO confirm against ion_functions.

    def process_stuck_value(self, coverage, parameter, input_name, resolution,
                            N):
        '''
        Evaluates the QC for stuck value for all data values that equal -88 (not yet evaluated)
        '''
        indexes = self._unevaluated_indexes(coverage, parameter)

        from ion_functions.qc.qc_functions import dataqc_stuckvaluetest_wrapper
        value_array = self.get_parameter_values(coverage, input_name)
        qc_array = dataqc_stuckvaluetest_wrapper(value_array, resolution,
                                                 N)[indexes]
        time_array = self.get_parameter_values(
            coverage, coverage.temporal_parameter_name)[indexes]

        return_dictionary = {
            coverage.temporal_parameter_name: time_array,
            parameter.name: qc_array
        }

    def process_trend_test(self, coverage, parameter, input_name, ord_n, nstd):
        '''
        Evaluates the QC for trend test for all data values that equal -88 (not yet evaluated)
        '''
        indexes = self._unevaluated_indexes(coverage, parameter)

        from ion_functions.qc.qc_functions import dataqc_polytrendtest_wrapper
        time_array = self.get_parameter_values(
            coverage, coverage.temporal_parameter_name)
        value_array = self.get_parameter_values(coverage, input_name)

        qc_array = dataqc_polytrendtest_wrapper(value_array, time_array, ord_n,
                                                nstd)
        return_dictionary = {
            coverage.temporal_parameter_name: time_array[indexes],
            parameter.name: qc_array[indexes]
        }

    def process_spike_test(self, coverage, parameter, input_name, acc, N, L):
        '''
        Evaluates the QC for spike test for all data values that equal -88 (not yet evaluated)
        '''
        indexes = self._unevaluated_indexes(coverage, parameter)

        from ion_functions.qc.qc_functions import dataqc_spiketest_wrapper
        value_array = self.get_parameter_values(coverage, input_name)
        qc_array = dataqc_spiketest_wrapper(value_array, acc, N, L)[indexes]
        time_array = self.get_parameter_values(
            coverage, coverage.temporal_parameter_name)[indexes]
        return_dictionary = {
            coverage.temporal_parameter_name: time_array,
            parameter.name: qc_array
        }

    def process_gradient_test(self, coverage, parameter, input_name, ddatdx,
                              mindx, startdat, toldat):
        '''
        Evaluates the QC for gradient test for all data values that equal -88 (not yet evaluated)
        '''
        indexes = self._unevaluated_indexes(coverage, parameter)

        from ion_functions.qc.qc_functions import dataqc_gradienttest_wrapper
        value_array = self.get_parameter_values(coverage, input_name)
        time_array = self.get_parameter_values(
            coverage, coverage.temporal_parameter_name)

        qc_array = dataqc_gradienttest_wrapper(value_array, time_array, ddatdx,
                                               mindx, startdat, toldat)

        return_dictionary = {
            coverage.temporal_parameter_name: time_array[indexes],
            parameter.name: qc_array[indexes]
        }

    def process_local_range_test(self, coverage, parameter, input_name, datlim,
                                 datlimz, dims):
        '''
        Local range test; currently disabled (returns immediately).
        '''
        return  # Not ready
        indexes = self._unevaluated_indexes(coverage, parameter)

        from ion_functions.qc.qc_functions import dataqc_localrangetest_wrapper
        # dat
        value_array = self.get_parameter_values(coverage, input_name)
        time_array = self.get_parameter_values(
            coverage, coverage.temporal_parameter_name)

        # datlim is an argument and comes from the lookup table
        # datlimz is an argument and comes from the lookup table
        # dims is an argument and is created using the column headings
        # pval_callback, well as for that...
        # TODO: slice_ is the window of the site data product, but for
        # now we'll just use a global slice
        slice_ = slice(None)

        def parameter_callback(param_name):
            return coverage.get_parameter_values(param_name, slice_)

        qc_array = dataqc_localrangetest_wrapper(value_array, datlim, datlimz,
                                                 dims, parameter_callback)
        return_dictionary = {
            coverage.temporal_parameter_name: time_array[indexes],
            parameter.name: qc_array[indexes]
        }
        log.debug("Here's what it would look like\n%s", return_dictionary)

    def get_dataset(self, data_product):
        '''
        Returns the first dataset id of the data product; raises BadRequest
        if it has none.
        '''
        dataset_ids, _ = self.resource_registry.find_objects(data_product,
                                                             PRED.hasDataset,
                                                             id_only=True)
        if not dataset_ids:
            raise BadRequest("No Dataset")
        return dataset_ids[0]

    def get_coverage(self, dataset_id):
        '''
        Opens the dataset's coverage in 'r+' mode; callers must close it.
        '''
        cov = DatasetManagementService._get_coverage(dataset_id, mode='r+')
        return cov

    def recent_row(self, rows):
        '''
        Determines the most recent data based on the timestamp

        Returns the first row with the largest 'ts_created' greater than 0,
        or None when `rows` is empty or no row has ts_created > 0.
        '''
        most_recent = None
        ts = 0
        for row in rows:
            if row['ts_created'] > ts:
                most_recent = row
                ts = row['ts_created']
        return most_recent

    def get_parameters(self, data_product):
        '''
        Returns the relevant parameter contexts of the data product

        When the stream definition lists available_fields, only contexts named
        there are returned. A data product without a stream definition or
        parameter dictionary yields an empty list instead of an IndexError
        that would abort the whole QC pass.
        '''

        # DataProduct -> StreamDefinition
        stream_defs, _ = self.resource_registry.find_objects(
            data_product._id, PRED.hasStreamDefinition, id_only=False)
        if not stream_defs:
            log.debug("Data product %s has no stream definition",
                      data_product._id)
            return []
        stream_def = stream_defs[0]

        # StreamDefinition -> ParameterDictionary
        pdict_ids, _ = self.resource_registry.find_objects(
            stream_def._id, PRED.hasParameterDictionary, id_only=True)
        if not pdict_ids:
            log.debug("Stream definition %s has no parameter dictionary",
                      stream_def._id)
            return []
        pdict_id = pdict_ids[0]

        # ParameterDictionary -> ParameterContext
        pctxts, _ = self.resource_registry.find_objects(
            pdict_id, PRED.hasParameterContext, id_only=False)
        available = stream_def.available_fields
        return [ctx for ctx in pctxts if not available or ctx.name in available]
예제 #53
0
class Socket(object):
    """
    Virtual Socket implementation, checks heartbeats, writes to local queues
    for message passing, holds the Namespace objects, dispatches the packets
    to the underlying namespaces.

    This is the abstraction on top of the different transports. It's like
    if you used a WebSocket only...
    """

    STATE_CONNECTING = "CONNECTING"
    STATE_CONNECTED = "CONNECTED"
    STATE_DISCONNECTING = "DISCONNECTING"
    STATE_DISCONNECTED = "DISCONNECTED"

    # Use this to be explicit when specifying a Global Namespace (an endpoint
    # with no name, not '/chat' or anything).
    GLOBAL_NS = ''

    json_loads = staticmethod(default_json_loads)
    json_dumps = staticmethod(default_json_dumps)

    def __init__(self, server, config, error_handler=None):
        # Weak proxy: the socket must not keep its server alive.
        self.server = weakref.proxy(server)
        # The session id keys this socket in ``server.sockets`` (see
        # detach()) and is presumably exposed to the client, so draw it from
        # the OS CSPRNG instead of the predictable default Mersenne Twister.
        self.sessid = str(random.SystemRandom().random())[2:]
        self.session = {}  # the session dict, for general developer usage
        self.client_queue = Queue()  # queue for messages to client
        self.server_queue = Queue()  # queue for messages to server
        self.hits = 0
        self.heartbeats = 0
        self.timeout = Event()  # set by heartbeat()
        self.wsgi_app_greenlet = None
        self.state = "NEW"
        self.connection_established = False
        self.ack_callbacks = {}  # msgid -> callback waiting for a client ack
        self.ack_counter = 0
        self.request = None
        self.environ = None
        self.namespaces = {}  # endpoint -> Namespace class
        self.active_ns = {}  # Namespace sessions that were instantiated
        self.jobs = []  # greenlets created through spawn()
        self.error_handler = default_error_handler
        self.config = config
        if error_handler is not None:
            self.error_handler = error_handler

    def _set_namespaces(self, namespaces):
        """This is a mapping (dict) of the different '/namespaces' to their
        BaseNamespace object derivative.

        This is called by socketio_manage()."""
        self.namespaces = namespaces

    def _set_request(self, request):
        """Saves the request object for future use by the different Namespaces.

        This is called by socketio_manage().
        """
        self.request = request

    def _set_environ(self, environ):
        """Save the WSGI environ, for future use.

        This is called by socketio_manage().
        """
        self.environ = environ

    def _set_error_handler(self, error_handler):
        """Changes the default error_handler function to the one specified

        This is called by socketio_manage().
        """
        self.error_handler = error_handler

    def _set_json_loads(self, json_loads):
        """Change the default JSON decoder.

        This should be a callable that accepts a single string, and returns
        a well-formed object.
        """
        self.json_loads = json_loads

    def _set_json_dumps(self, json_dumps):
        """Change the default JSON encoder.

        This should be a callable that accepts a single object and returns
        its JSON string; it is handed to ``packet.encode`` by send_packet().
        """
        self.json_dumps = json_dumps

    def _get_next_msgid(self):
        """This retrieves the next value for the 'id' field when sending
        an 'event' or 'message' or 'json' that asks the remote client
        to 'ack' back, so that we trigger the local callback.
        """
        self.ack_counter += 1
        return self.ack_counter

    def _save_ack_callback(self, msgid, callback):
        """Keep a reference of the callback on this socket.

        Returns False (and keeps the existing callback) if one is already
        registered for ``msgid``; otherwise stores it and returns None.
        """
        if msgid in self.ack_callbacks:
            return False
        self.ack_callbacks[msgid] = callback

    def _pop_ack_callback(self, msgid):
        """Fetch and remove the callback for a given msgid, if it exists,
        otherwise return None"""
        return self.ack_callbacks.pop(msgid, None)

    def __str__(self):
        result = ['sessid=%r' % self.sessid]
        if self.state == self.STATE_CONNECTED:
            result.append('connected')
        if self.client_queue.qsize():
            result.append('client_queue[%s]' % self.client_queue.qsize())
        if self.server_queue.qsize():
            result.append('server_queue[%s]' % self.server_queue.qsize())
        if self.hits:
            result.append('hits=%s' % self.hits)
        if self.heartbeats:
            result.append('heartbeats=%s' % self.heartbeats)

        return ' '.join(result)

    def __getitem__(self, key):
        """This will get the nested Namespace using its '/chat' reference.

        Using this, you can go from one Namespace to the other (to emit, add
        ACLs, etc..) with:

          adminnamespace.socket['/chat'].add_acl_method('kick-ban')

        """
        return self.active_ns[key]

    def __contains__(self, key):
        """Verifies if the namespace is active (was initialized).

        Enables ``'/chat' in socket``.
        """
        return key in self.active_ns

    # Backward-compatible alias: ``__hasitem__`` is not a Python protocol
    # method, so the ``in`` operator never used it.
    __hasitem__ = __contains__

    @property
    def connected(self):
        """Returns whether the state is CONNECTED or not."""
        return self.state == self.STATE_CONNECTED

    def incr_hits(self):
        self.hits += 1

    def heartbeat(self):
        """This makes the heart beat for another X seconds.  Call this when
        you get a heartbeat packet in.

        This clear the heartbeat disconnect timeout (resets for X seconds).
        """
        self.timeout.set()

    def kill(self, detach=False):
        """This function must/will be called when a socket is to be completely
        shut down, closed by connection timeout, connection error or explicit
        disconnection from the client.

        It will call all of the Namespace's
        :meth:`~socketio.namespace.BaseNamespace.disconnect` methods
        so that you can shut-down things properly.

        When connected, a ``None`` sentinel is pushed on both queues to wake
        up any greenlet blocked reading them.
        """
        # Clear out the callbacks
        self.ack_callbacks = {}
        if self.connected:
            self.state = self.STATE_DISCONNECTING
            self.server_queue.put_nowait(None)
            self.client_queue.put_nowait(None)
            if len(self.active_ns) > 0:
                log.debug("Calling disconnect() on %s" % self)
                self.disconnect()

        if detach:
            self.detach()

        gevent.killall(self.jobs)

    def detach(self):
        """Detach this socket from the server. This should be done in
        conjunction with kill(), once all the jobs are dead, detach the
        socket for garbage collection."""

        log.debug("Removing %s from server sockets" % self)
        if self.sessid in self.server.sockets:
            self.server.sockets.pop(self.sessid)

    def put_server_msg(self, msg):
        """Writes to the server's pipe, to end up in in the Namespaces"""
        self.heartbeat()
        self.server_queue.put_nowait(msg)

    def put_client_msg(self, msg):
        """Writes to the client's pipe, to end up in the browser"""
        self.client_queue.put_nowait(msg)

    def get_client_msg(self, **kwargs):
        """Grab a message to send it to the browser"""
        return self.client_queue.get(**kwargs)

    def get_server_msg(self, **kwargs):
        """Grab a message, to process it by the server and dispatch calls
        """
        return self.server_queue.get(**kwargs)

    def get_multiple_client_msgs(self, **kwargs):
        """Get multiple messages, in case we're going through the various
        XHR-polling methods, on which we can pack more than one message if the
        rate is high, and encode the payload for the HTTP channel."""
        client_queue = self.client_queue
        msgs = [client_queue.get(**kwargs)]
        while client_queue.qsize():
            msgs.append(client_queue.get())
        return msgs

    def error(self, error_name, error_message, endpoint=None, msg_id=None,
              quiet=False):
        """Send an error to the user, using the custom or default
        ErrorHandler configured on the [TODO: Revise this] Socket/Handler
        object.

        :param error_name: is a simple string, for easy association on
                           the client side

        :param error_message: is a human readable message, the user
                              will eventually see

        :param endpoint: set this if you have a message specific to an
                         end point

        :param msg_id: set this if your error is relative to a
                       specific message

        :param quiet: way to make the error handler quiet. Specific to
                      the handler.  The default handler will only log,
                      with quiet.
        """
        handler = self.error_handler
        return handler(
            self, error_name, error_message, endpoint, msg_id, quiet)

    # User facing low-level function
    def disconnect(self, silent=False):
        """Calling this method will call ``recv_disconnect()`` on all the
        active Namespaces that were open.

        Normally, the Global namespace (endpoint = '') has special meaning,
        as it represents the whole connection,

        :param silent: accepted for API compatibility; it is currently NOT
                       forwarded to the namespaces.
        """
        # Iterate over a snapshot: recv_disconnect() may remove namespaces
        # from active_ns (see remove_namespace()).
        for ns in list(self.active_ns.values()):
            ns.recv_disconnect()

    def remove_namespace(self, namespace):
        """This removes a Namespace object from the socket.

        This is usually called by
        :meth:`~socketio.namespace.BaseNamespace.disconnect`.

        Once the last namespace is gone, a still-connected socket is killed
        and detached.
        """
        if namespace in self.active_ns:
            del self.active_ns[namespace]

        if len(self.active_ns) == 0 and self.connected:
            self.kill(detach=True)

    def send_packet(self, pkt):
        """Low-level interface to queue a packet on the wire (encoded as wire
        protocol)"""
        self.put_client_msg(packet.encode(pkt, self.json_dumps))

    def spawn(self, fn, *args, **kwargs):
        """Spawn a new Greenlet, attached to this Socket instance.

        It will be monitored by the "watcher" method
        """

        log.debug("Spawning sub-Socket Greenlet: %s" % fn.__name__)
        job = gevent.spawn(fn, *args, **kwargs)
        self.jobs.append(job)
        return job

    def _receiver_loop(self):
        """This is the loop that takes messages from the queue for the server
        to consume, decodes them and dispatches them.

        It is the main loop for a socket.  We join on this process before
        returning control to the web framework.

        This process is not tracked by the socket itself, it is not going
        to be killed by the ``gevent.killall(socket.jobs)``, so it must
        exit gracefully itself: it returns when the socket is no longer
        connected after a packet, or when it wakes up on the empty sentinel
        that kill() enqueues while shutting down.
        """

        while True:
            rawdata = self.get_server_msg()

            if not rawdata:
                # kill() pushes a None sentinel to wake us up; once the
                # socket is shutting down, leave instead of blocking on the
                # queue forever.
                if self.state in (self.STATE_DISCONNECTING,
                                  self.STATE_DISCONNECTED):
                    return
                continue
            try:
                pkt = packet.decode(rawdata, self.json_loads)
            except Exception as e:
                self.error('invalid_packet',
                    "There was a decoding error when dealing with packet "
                    "with event: %s... (%s)" % (rawdata[:20], e))
                continue

            if pkt['type'] == 'heartbeat':
                # This is already dealt with in put_server_msg() when
                # any incoming raw data arrives.
                continue

            if pkt['type'] == 'disconnect' and pkt['endpoint'] == '':
                # On global namespace, we kill everything.
                self.kill(detach=True)
                continue

            endpoint = pkt['endpoint']

            if endpoint not in self.namespaces:
                self.error("no_such_namespace",
                    "The endpoint you tried to connect to "
                    "doesn't exist: %s" % endpoint, endpoint=endpoint)
                continue
            elif endpoint in self.active_ns:
                pkt_ns = self.active_ns[endpoint]
            else:
                new_ns_class = self.namespaces[endpoint]
                pkt_ns = new_ns_class(self.environ, endpoint,
                                        request=self.request)
                # Call each class's OWN initialize() (use this instead of
                # __init__, for less confusion), in MRO order. Checking the
                # class __dict__ rather than hasattr() avoids re-running an
                # inherited initialize() once per subclass in the MRO.
                for cls in type(pkt_ns).__mro__:
                    if 'initialize' in vars(cls):
                        cls.initialize(pkt_ns)

                self.active_ns[endpoint] = pkt_ns

            retval = pkt_ns.process_packet(pkt)

            # Has the client requested an 'ack' with the reply parameters ?
            if pkt.get('ack') == "data" and pkt.get('id'):
                # The ack carries a list of arguments: a tuple return value
                # is spread, any other value becomes a single argument.
                if type(retval) is tuple:
                    args = list(retval)
                else:
                    args = [retval]
                returning_ack = dict(type='ack', ackId=pkt['id'],
                                     args=args,
                                     endpoint=pkt.get('endpoint', ''))
                self.send_packet(returning_ack)

            # Now, are we still connected ?
            if not self.connected:
                # Not a user-initiated disconnect: clean up fully.
                self.kill(detach=True)
                return
예제 #54
0
class UserAddressManager:
    """ Matrix user <-> eth address mapping and user / address reachability helper.

    In Raiden the smallest unit of addressability is a node with an associated Ethereum address.
    In Matrix it's a user. Matrix users are (at the moment) bound to a specific homeserver.
    Since we want to provide resiliency against unavailable homeservers a single Raiden node with
    a single Ethereum address can be in control over multiple Matrix users on multiple homeservers.

    Therefore we need to perform a many-to-one mapping of Matrix users to Ethereum addresses.
    Each Matrix user has a presence state (ONLINE, OFFLINE).
    One of the preconditions of running a Raiden node is that there can always only be one node
    online for a particular address at a time.
    That means we can synthesize the reachability of an address from the user presence states.

    This helper internally tracks both the user presence and address reachability for addresses
    that have been marked as being 'interesting' (by calling the `.add_address()` method).
    Additionally it provides the option of passing callbacks that will be notified when
    presence / reachability change.
    """

    def __init__(
        self,
        client: GMatrixClient,
        displayname_cache: DisplayNameCache,
        address_reachability_changed_callback: Callable[[Address, AddressReachability], None],
        user_presence_changed_callback: Optional[Callable[[User, UserPresence], None]] = None,
        _log_context: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._client = client
        self._displayname_cache = displayname_cache
        self._address_reachability_changed_callback = address_reachability_changed_callback
        self._user_presence_changed_callback = user_presence_changed_callback
        # Set by stop(); makes _presence_listener ignore late events.
        self._stop_event = Event()

        self._reset_state()

        # Extra key/values bound into every log line; never mutated here.
        self._log_context = _log_context
        # Cached bound logger, filled lazily by the ``log`` property.
        self._log = None
        self._listener_id: Optional[UUID] = None

    def start(self) -> None:
        """ Start listening for presence updates.

        Should be called before ``.login()`` is called on the underlying client. """
        assert self._listener_id is None, "UserAddressManager.start() called twice"
        self._stop_event.clear()
        self._listener_id = self._client.add_presence_listener(self._presence_listener)

    def stop(self) -> None:
        """ Stop listening on presence updates and forget all tracked state. """
        assert self._listener_id is not None, "UserAddressManager.stop() called before start"
        self._stop_event.set()
        self._client.remove_presence_listener(self._listener_id)
        self._listener_id = None
        self._log = None
        self._reset_state()

    @property
    def known_addresses(self) -> Set[Address]:
        """ Return all addresses we keep track of """
        # This must return a copy of the current keys, because the container
        # may be modified while these values are used. Issue: #5240
        return set(self._address_to_userids)

    def is_address_known(self, address: Address) -> bool:
        """ Is the given ``address`` reachability being monitored? """
        return address in self._address_to_userids

    def add_address(self, address: Address) -> None:
        """ Add ``address`` to the known addresses that are being observed for reachability. """
        # Since _address_to_userids is a defaultdict accessing the key creates the entry
        _ = self._address_to_userids[address]

    def add_userid_for_address(self, address: Address, user_id: str) -> None:
        """ Add a ``user_id`` for the given ``address``.

        Implicitly adds the address if it was unknown before.
        """
        self._address_to_userids[address].add(user_id)

    def add_userids_for_address(self, address: Address, user_ids: Iterable[str]) -> None:
        """ Add multiple ``user_ids`` for the given ``address``.

        Implicitly adds any addresses if they were unknown before.
        """
        self._address_to_userids[address].update(user_ids)

    def get_userids_for_address(self, address: Address) -> Set[str]:
        """ Return all known user ids for the given ``address``.

        For a known address this is the internal set itself (not a copy);
        for an unknown address a new empty set is returned and the address
        is NOT added.
        """
        if not self.is_address_known(address):
            return set()
        return self._address_to_userids[address]

    def get_userid_presence(self, user_id: str) -> UserPresence:
        """ Return the current presence state of ``user_id``. """
        return self._userid_to_presence.get(user_id, UserPresence.UNKNOWN)

    def get_address_reachability(self, address: Address) -> AddressReachability:
        """ Return the current reachability state for ``address``. """
        return self._address_to_reachability.get(address, AddressReachability.UNKNOWN)

    def force_user_presence(self, user: User, presence: UserPresence) -> None:
        """ Forcibly set the ``user`` presence to ``presence``.

        This method is only provided to cover an edge case in our use of the Matrix protocol and
        should **not** generally be used. No callbacks are triggered.
        """
        self._userid_to_presence[user.user_id] = presence

    def populate_userids_for_address(self, address: Address, force: bool = False) -> None:
        """ Populate known user ids for the given ``address`` from the server directory.

        If ``force`` is ``True`` perform the directory search even if there
        already are known users. Only users whose signature validates are added.
        """
        if force or not self.get_userids_for_address(address):
            self.add_userids_for_address(
                address,
                (
                    user.user_id
                    for user in self._client.search_user_directory(to_normalized_address(address))
                    if self._validate_userid_signature(user)
                ),
            )

    def track_address_presence(self, address: Address, user_ids: Set[str]) -> None:
        """
        Update synthesized address presence state from cached user presence states.

        Fetches the presence of each of ``user_ids`` from the server, records
        it, then triggers the reachability callback (if any) in case the
        address state has changed.

        This method is only provided to cover an edge case in our use of the Matrix protocol and
        should **not** generally be used.
        """
        self.add_userids_for_address(address, user_ids)
        userids_to_presence = {}
        for uid in user_ids:
            presence = self._fetch_user_presence(uid)
            userids_to_presence[uid] = presence
            self._set_user_presence(uid, presence)

        # Use the bound logger like the rest of the class so the line carries
        # the same context (current_user / node) as the others.
        self.log.debug(
            "Fetched user presences",
            address=to_checksum_address(address),
            userids_to_presence=userids_to_presence,
        )

        self._maybe_address_reachability_changed(address)

    def _maybe_address_reachability_changed(self, address: Address) -> None:
        # A Raiden node may have multiple Matrix users, this happens when
        # Raiden roams from a Matrix server to another. This loop goes over all
        # these users and uses the "best" presence. IOW, if there is a single
        # Matrix user that is reachable, then the Raiden node is considered
        # reachable. "Best" is the first UserPresence member (declaration
        # order) found among the users' presences.
        userids = self._address_to_userids[address].copy()
        composite_presence = {self._userid_to_presence.get(uid) for uid in userids}

        new_presence = UserPresence.UNKNOWN
        for presence in UserPresence.__members__.values():
            if presence in composite_presence:
                new_presence = presence
                break

        new_address_reachability = USER_PRESENCE_TO_ADDRESS_REACHABILITY[new_presence]

        prev_address_reachability = self.get_address_reachability(address)
        if new_address_reachability == prev_address_reachability:
            return

        self.log.debug(
            "Changing address reachability state",
            address=to_checksum_address(address),
            prev_state=prev_address_reachability,
            state=new_address_reachability,
        )

        self._address_to_reachability[address] = new_address_reachability
        self._address_reachability_changed_callback(address, new_address_reachability)

    def _presence_listener(self, event: Dict[str, Any]) -> None:
        """
        Update cached user presence state from Matrix presence events.

        Due to the possibility of nodes using accounts on multiple homeservers a composite
        address state is synthesised from the cached individual user presence states.
        Events are ignored after stop(), for our own user, for non-presence
        events and for addresses that are not being tracked.
        """
        if self._stop_event.ready():
            return

        user_id = event["sender"]

        if event["type"] != "m.presence" or user_id == self._user_id:
            return

        address = address_from_userid(user_id)

        # Not a user we've whitelisted, skip. This needs to be on the top of
        # the function so that we don't request the displayname of users that
        # are not important for the node. The presence is updated for every
        # user on the first sync, since every Raiden node is a member of a
        # broadcast room. This can result in thousands requests to the Matrix
        # server in the first sync which will lead to slow startup times and
        # presence problems.
        if address is None or not self.is_address_known(address):
            return

        user = self._user_from_id(user_id, event["content"].get("displayname"))

        if not user:
            return

        # Trust only the address recovered from the displayname signature.
        address = self._validate_userid_signature(user)
        if not address:
            return

        self._displayname_cache.warm_users([user])

        self.add_userid_for_address(address, user_id)

        new_state = UserPresence(event["content"]["presence"])

        self._set_user_presence(user_id, new_state)
        self._maybe_address_reachability_changed(address)

    def _reset_state(self) -> None:
        self._address_to_userids: Dict[Address, Set[str]] = defaultdict(set)
        self._address_to_reachability: Dict[Address, AddressReachability] = dict()
        self._userid_to_presence: Dict[str, UserPresence] = dict()

    @property
    def _user_id(self) -> str:
        user_id = getattr(self._client, "user_id", None)
        assert user_id, f"{self.__class__.__name__}._user_id accessed before client login"
        return user_id

    def _user_from_id(self, user_id: str, display_name: Optional[str] = None) -> Optional[User]:
        try:
            return User(self._client.api, user_id, display_name)
        except ValueError:
            log.error("Matrix server returned an invalid user_id.")
        return None

    def _fetch_user_presence(self, user_id: str) -> UserPresence:
        try:
            presence = UserPresence(self._client.get_user_presence(user_id))
        except MatrixRequestError:
            # The following exception will be raised if the local user and the
            # target user do not have a shared room:
            #
            #   MatrixRequestError: 403:
            #   {"errcode":"M_FORBIDDEN","error":"You are not allowed to see their presence."}
            presence = UserPresence.UNKNOWN
            log.exception("Could not fetch user presence")

        return presence

    def _set_user_presence(self, user_id: str, presence: UserPresence) -> None:
        user = self._user_from_id(user_id)
        if not user:
            return

        old_presence = self._userid_to_presence.get(user_id)
        if old_presence != presence:
            self._userid_to_presence[user_id] = presence
            self.log.debug(
                "Changing user presence state",
                user_id=user_id,
                prev_state=old_presence,
                state=presence,
            )
            if self._user_presence_changed_callback:
                self._user_presence_changed_callback(user, presence)

    @staticmethod
    def _validate_userid_signature(user: User) -> Optional[Address]:
        return validate_userid_signature(user)

    @property
    def log(self) -> BoundLoggerLazyProxy:
        """ Logger bound to ``_log_context`` plus, once logged in, the local user/node.

        The bound logger is cached only after the client has a non-empty
        ``user_id``; before that a fresh logger bound to ``_log_context``
        alone is returned on every access.
        """
        if self._log:
            return self._log

        # Copy: the keys added below must not leak into the caller-supplied
        # ``_log_context`` dict (which may be shared with other objects).
        context = dict(self._log_context or {})

        # Only cache the logger once the user_id becomes available. A
        # truthiness check (not hasattr) avoids tripping the assertion in
        # ``_user_id`` when the attribute exists but is still empty.
        if getattr(self._client, "user_id", None):
            context["current_user"] = self._user_id
            context["node"] = node_address_from_userid(self._user_id)

            bound_log = log.bind(**context)
            self._log = bound_log
            return bound_log

        # Apply the `_log_context` even if the user_id is not yet available
        return log.bind(**context)
예제 #55
0
class PSComm(object):
    """Node of a publish/subscribe overlay exchanging JSON messages over TCP.

    Each node runs a ``StreamServer`` that accepts one JSON message per
    connection, keeps a subscription table (``subs_table``), forwards
    subscriptions to its neighbours and routes publications towards the
    neighbours the table reports as interested.
    """

    # Upper bound, in bytes, on a single incoming message so a peer cannot
    # make req_handler buffer without limit.
    MAX_MSG_SIZE = 1024 * 1024

    def __init__(self, nodename, hostname, port, adj):
        self.nodename = nodename  # this node's name, sent as 'from'
        self.hostname = hostname
        self.port = port
        # Neighbours as (nodename, hostname, port) triples.
        self.adj = adj
        self.server_greenlet = None
        self.subs_table = STable()
        # Used as a mutex around subs_table updates: set == table free.
        self.table_available = Event()
        self.table_available.set()
        # (item_name, content, timestamp string) of publications for this node.
        self.fetched_publications = []

    def _recv_all(self, sock, bufsize=4096):
        """Read from ``sock`` until the peer closes its side.

        Returns the received bytes, or None if they exceed MAX_MSG_SIZE.
        """
        chunks = []
        received = 0
        while True:
            chunk = sock.recv(bufsize)
            if not chunk:
                break
            received += len(chunk)
            if received > self.MAX_MSG_SIZE:
                return None
            chunks.append(chunk)
        return b''.join(chunks)

    def req_handler(self, sock, clientaddress):
        """Handle one incoming connection carrying a single JSON message.

        The whole message is read (until EOF, bounded by MAX_MSG_SIZE) and
        the connection is closed, then the message is dispatched:

        * ``subscription``: merged into the subscription table; if the table
          changed, the subscription is forwarded to the other neighbours.
        * ``publish``: stored locally if this node is among the interested
          next hops, otherwise forwarded to them.

        Messages that are not JSON objects or lack ``from`` are reported and
        dropped.
        """
        try:
            msg = self._recv_all(sock)
        finally:
            sock.close()

        if msg is None:
            print('ERRO: MENSAGEM MUITO GRANDE')
            return

        try:
            msg = json.loads(msg)
        except ValueError as e:
            # ``e.message`` does not exist on Python 3; format the exception.
            print('ERRO: MENSAGEM INVALIDA: {}'.format(e))
            print('Recebido: {}'.format(msg))
            msg = {}

        if not isinstance(msg, dict) or 'from' not in msg:
            print('ERRO: MENSAGEM SEM IDENTIFICACAO')
            return

        sender = msg['from']

        if 'subscription' in msg:
            subs = msg['subscription']
            self.table_available.wait()
            self.table_available.clear()
            try:
                updated = self.subs_table.update_table(subs['item'],
                                                       subs['subscriber'],
                                                       subs['next_node'],
                                                       subs['hops'])
            finally:
                # Always release, or a failing update would block every
                # later subscription forever.
                self.table_available.set()
            if updated:
                self.transmit_subscription(sender, subs)

        if 'publish' in msg:
            item_name = msg['publish']['item_name']
            content = msg['publish']['content']
            next_hops = self.subs_table.get_interested_adj(item_name)
            if self.nodename in next_hops:
                self.fetched_publications.append(
                    (item_name, content, str(datetime.datetime.now())))
            else:
                self.transmit_publications(sender, item_name, content,
                                           next_hops)

    def start_server(self):
        """Serve incoming connections with req_handler; blocks forever."""
        server = StreamServer((self.hostname, self.port), self.req_handler)
        server.serve_forever()

    def receive_req(self):
        pass

    def send_req(self, hostname, port, msg):
        """Send ``msg`` over a new TCP connection to ``hostname:port``.

        ``msg`` may be text (encoded as UTF-8) or bytes. The connection is
        closed afterwards, which also marks the end of the message for the
        receiver. Returns True once sent, False if the connection could not
        be established.
        """
        try:
            sock = socket.create_connection((hostname, port))
        except socket.error:
            print('ERRO: conexao recusada para {}:{}'.format(hostname, port))
            return False
        try:
            if not isinstance(msg, bytes):
                msg = msg.encode('utf-8')
            # sendall(): a single send() may transmit only part of the data.
            sock.sendall(msg)
        finally:
            sock.close()
        return True

    def transmit_subscription(self, sender, subs):
        """Forward ``subs`` (hop count +1, next_node = us) to every neighbour
        except ``sender``. Mutates ``subs`` in place."""
        subs['hops'] += 1
        subs['next_node'] = self.nodename
        for adj_node in self.adj:
            nodename, hostname, port = adj_node
            if nodename != sender:
                self.send_req(
                    hostname, port,
                    json.dumps({
                        'from': self.nodename,
                        'subscription': subs
                    }))

    def transmit_publications(self, sender, item_name, content, next_hops):
        """Forward a publication once to each neighbour in ``next_hops``,
        skipping ``sender``."""
        already_sent = set()
        for node_id in next_hops:
            if node_id == sender or node_id in already_sent:
                continue
            for adj in self.adj:
                if node_id == adj[0]:
                    self.send_req(
                        adj[1], adj[2],
                        json.dumps({
                            'from': self.nodename,
                            'publish': {
                                'item_name': item_name,
                                'content': content
                            }
                        }))
                    already_sent.add(node_id)

    def start(self):
        """Spawn the server greenlet and keep the caller alive forever."""
        self.server_greenlet = gevent.spawn(self.start_server)
        while True:
            gevent.sleep(1)
예제 #56
0
파일: echo_node.py 프로젝트: vnblr/raiden
class EchoNode:
    """Raiden node that automatically answers token transfers it receives.

    On every raiden alarm tick the node polls all channels of ``token_address``
    for ``EventTransferReceivedSuccess`` events, queues the new ones, and a
    worker greenlet answers each unseen transfer following the rules described
    in :meth:`on_transfer`.
    """

    def __init__(self, api, token_address):
        """Prepare the echo node and register it with the raiden alarm task.

        If no channel for the token is open yet, the node connects to the token
        network using its whole token balance.

        :param api: a ``RaidenAPI`` instance.
        :param token_address: address of the token whose transfers are echoed.
        :raises ValueError: if no channel is open and the node holds no
            balance of the token.
        """
        assert isinstance(api, RaidenAPI)
        # set on the first alarm callback
        self.ready = Event()

        self.api = api
        self.token_address = token_address

        existing_channels = self.api.get_channel_list(
            api.raiden.default_registry.address,
            self.token_address,
        )

        open_channels = [
            channel_state for channel_state in existing_channels
            if channel.get_status(channel_state) == CHANNEL_STATE_OPENED
        ]

        if len(open_channels) == 0:
            token = self.api.raiden.chain.token(self.token_address)
            # Query the balance once and reuse it for the check and the deposit.
            balance = token.balance_of(self.api.raiden.address)
            if not balance > 0:
                raise ValueError(
                    'not enough funds for echo node %s for token %s' % (
                        pex(self.api.raiden.address),
                        pex(self.token_address),
                    ))
            self.api.connect_token_network(
                self.token_address,
                balance,
                initial_channel_target=10,
                joinable_funds_target=.5,
            )

        self.last_poll_block = self.api.raiden.get_block_number()
        self.received_transfers = Queue()
        self.stop_signal = None  # used to signal REMOVE_CALLBACK and stop echo_workers
        # outstanding greenlets that stop() waits for
        self.greenlets = list()
        # guards poll_all_received_events against concurrent runs
        self.lock = BoundedSemaphore()
        # bounded memory of handled transfers, used to drop duplicates
        self.seen_transfers = deque(list(), TRANSFER_MEMORY)
        self.num_handled_transfers = 0
        # entries of the "lucky number 7" lottery
        self.lottery_pool = Queue()
        # register ourselves with the raiden alarm task
        self.api.raiden.alarm.register_callback(self.echo_node_alarm_callback)
        self.echo_worker_greenlet = gevent.spawn(self.echo_worker)

    def echo_node_alarm_callback(self, block_number):
        """ This can be registered with the raiden AlarmTask.
        If `EchoNode.stop()` is called, it will give the return signal to be removed from
        the AlarmTask callbacks.

        Sets ``self.ready`` on the first call. Otherwise spawns a
        :meth:`poll_all_received_events` greenlet and returns ``True``.
        """
        if not self.ready.is_set():
            self.ready.set()
        log.debug('echo_node callback', block_number=block_number)
        if self.stop_signal is not None:
            return REMOVE_CALLBACK
        # Drop finished greenlets; otherwise this list grows by at least one
        # entry per block for the whole lifetime of the node.
        self.greenlets = [g for g in self.greenlets if not g.dead]
        self.greenlets.append(gevent.spawn(self.poll_all_received_events))
        return True

    def poll_all_received_events(self):
        """ This will be triggered once for each `echo_node_alarm_callback`.
        It polls all channels for `EventTransferReceivedSuccess` events,
        adds all new events to the `self.received_transfers` queue and
        respawns `self.echo_worker`, if its greenlet is no longer running.

        Returns immediately if another poll currently holds ``self.lock``;
        the whole poll is bounded by a 10 second timeout.
        """

        locked = False
        try:
            with Timeout(10):
                locked = self.lock.acquire(blocking=False)
                if not locked:
                    return
                else:
                    channels = self.api.get_channel_list(
                        registry_address=self.api.raiden.default_registry.
                        address,
                        token_address=self.token_address,
                    )
                    received_transfers = list()
                    for channel_state in channels:
                        channel_events = self.api.get_channel_events(
                            channel_state.identifier,
                            self.last_poll_block,
                        )
                        received_transfers.extend([
                            event for event in channel_events
                            if event['event'] == 'EventTransferReceivedSuccess'
                        ])
                    for event in received_transfers:
                        transfer = event.copy()
                        # drop block_number so the same transfer compares equal
                        # in seen_transfers regardless of when it was polled
                        transfer.pop('block_number')
                        self.received_transfers.put(transfer)
                    # set last_poll_block after events are enqueued (timeout safe)
                    if received_transfers:
                        self.last_poll_block = max(
                            event['block_number']
                            for event in received_transfers)
                    # increase last_poll_block if the blockchain proceeded
                    delta_blocks = self.api.raiden.get_block_number(
                    ) - self.last_poll_block
                    if delta_blocks > 1:
                        self.last_poll_block += 1

                    # NOTE(review): assumes gevent's (deprecated) `started`
                    # is false once the greenlet finished — TODO confirm for
                    # the pinned gevent version.
                    if not self.echo_worker_greenlet.started:
                        log.debug(
                            'restarting echo_worker_greenlet',
                            dead=self.echo_worker_greenlet.dead,
                            successful=self.echo_worker_greenlet.successful(),
                            exception=self.echo_worker_greenlet.exception,
                        )
                        self.echo_worker_greenlet = gevent.spawn(
                            self.echo_worker)
        except Timeout:
            log.info('timeout while polling for events')
        finally:
            if locked:
                self.lock.release()

    def echo_worker(self):
        """ The `echo_worker` works through the `self.received_transfers` queue and spawns
        `self.on_transfer` greenlets for all not-yet-seen transfers. Runs until
        `stop()` sets `self.stop_signal`. """
        log.debug('echo worker', qsize=self.received_transfers.qsize())
        while self.stop_signal is None:
            if self.received_transfers.qsize() > 0:
                transfer = self.received_transfers.get()
                if transfer in self.seen_transfers:
                    log.debug(
                        'duplicate transfer ignored',
                        initiator=pex(transfer['initiator']),
                        amount=transfer['amount'],
                        identifier=transfer['identifier'],
                    )
                else:
                    self.seen_transfers.append(transfer)
                    self.greenlets.append(
                        gevent.spawn(self.on_transfer, transfer))
            else:
                gevent.sleep(.5)

    def on_transfer(self, transfer):
        """ This handles the echo logic, as described in
        https://github.com/raiden-network/raiden/issues/651:

            - for transfers with an amount that satisfies `amount % 3 == 0`, it sends a transfer
            with an amount of `amount - 1` back to the initiator
            - for transfers with a "lucky number" amount `amount == 7` it does not send anything
            back immediately -- after having received "lucky number transfers" from 7 different
            addresses it sends a transfer with `amount = 49` to one randomly chosen one
            (from the 7 lucky addresses)
            - consecutive entries to the lucky lottery will receive the current pool size as the
            `echo_amount`
            - for all other transfers it sends a transfer with the same `amount` back to the
            initiator

        Only positive echo amounts are sent. The echo uses the identifier
        `transfer['identifier'] + echo_amount`. """
        echo_amount = 0
        if transfer['amount'] % 3 == 0:
            log.debug(
                'minus one transfer received',
                initiator=pex(transfer['initiator']),
                amount=transfer['amount'],
                identifier=transfer['identifier'],
            )
            echo_amount = transfer['amount'] - 1

        elif transfer['amount'] == 7:
            log.debug(
                'lucky number transfer received',
                initiator=pex(transfer['initiator']),
                amount=transfer['amount'],
                identifier=transfer['identifier'],
                poolsize=self.lottery_pool.qsize(),
            )

            # obtain a local copy of the pool
            pool = self.lottery_pool.copy()
            tickets = [pool.get() for _ in range(pool.qsize())]
            assert pool.empty()
            del pool

            if any(ticket['initiator'] == transfer['initiator']
                   for ticket in tickets):
                assert transfer not in tickets
                log.debug(
                    'duplicate lottery entry',
                    initiator=pex(transfer['initiator']),
                    identifier=transfer['identifier'],
                    poolsize=len(tickets),
                )
                # signal the poolsize to the participant
                echo_amount = len(tickets)

            # payout
            elif len(tickets) == 6:
                log.info('payout!')
                # reset the pool
                assert self.lottery_pool.qsize() == 6
                self.lottery_pool = Queue()
                # add new participant
                tickets.append(transfer)
                # choose the winner
                transfer = random.choice(tickets)
                echo_amount = 49
            else:
                self.lottery_pool.put(transfer)

        else:
            log.debug(
                'echo transfer received',
                initiator=pex(transfer['initiator']),
                amount=transfer['amount'],
                identifier=transfer['identifier'],
            )
            echo_amount = transfer['amount']

        # `> 0` rather than truthiness: an amount of 0 would otherwise yield -1.
        if echo_amount > 0:
            log.debug(
                'sending echo transfer',
                target=pex(transfer['initiator']),
                amount=echo_amount,
                orig_identifier=transfer['identifier'],
                echo_identifier=transfer['identifier'] + echo_amount,
                token_address=pex(self.token_address),
                num_handled_transfers=self.num_handled_transfers + 1,
            )

            # `transfer` is a plain dict (indexed everywhere above), so it has
            # no `registry_address` attribute; use the node's default registry
            # exactly like the channel queries in this class do.
            self.api.transfer_and_wait(
                self.api.raiden.default_registry.address,
                self.token_address,
                echo_amount,
                transfer['initiator'],
                identifier=transfer['identifier'] + echo_amount,
            )
        self.num_handled_transfers += 1

    def stop(self):
        """Signal the worker and alarm callback to stop, then wait for all
        outstanding greenlets (including the echo worker) to finish."""
        self.stop_signal = True
        self.greenlets.append(self.echo_worker_greenlet)
        gevent.wait(self.greenlets)
예제 #57
0
class Server(paramiko.ServerInterface):
    """Implements :mod:`paramiko.ServerInterface` to provide an
    embedded SSH2 server implementation.

    Start a `Server` with at least a :mod:`paramiko.Transport` object
    and a host private key.

    Only password and public key authentication are offered; any
    credentials are accepted unless ``fail_auth`` or ``ssh_exception``
    is set. Interactive shell requests are not accepted.

    Implemented:
    * Direct tcp-ip channels (tunneling)
    * SSH Agent forwarding on request
    * PTY requests
    * Exec requests (run a command on server)

    Not Implemented:
    * Interactive shell requests
    """

    def __init__(self, transport, host_key, fail_auth=False,
                 ssh_exception=False,
                 encoding='utf-8'):
        """
        :param transport: ``paramiko.Transport`` to serve; the host key and an
            SFTP subsystem handler are registered on it.
        :param host_key: server host private key.
        :param fail_auth: when true, every authentication attempt fails.
        :param ssh_exception: when true (and ``fail_auth`` is false),
            authentication raises ``paramiko.SSHException``.
        :param encoding: used to decode env requests and exported to executed
            commands as ``PYTHONIOENCODING``.
        """
        paramiko.ServerInterface.__init__(self)
        transport.load_server_moduli()
        transport.add_server_key(host_key)
        transport.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer)
        self.transport = transport
        # set once a tunnel is started or an exec request arrives
        self.event = Event()
        self.fail_auth = fail_auth
        self.ssh_exception = ssh_exception
        self.host_key = host_key
        self.encoding = encoding

    def check_channel_request(self, kind, chanid):
        """Accept every channel type."""
        return paramiko.OPEN_SUCCEEDED

    def check_auth_password(self, username, password):
        """Accept any password unless configured to fail or raise."""
        if self.fail_auth:
            return paramiko.AUTH_FAILED
        if self.ssh_exception:
            raise paramiko.SSHException()
        return paramiko.AUTH_SUCCESSFUL

    def check_auth_publickey(self, username, key):
        """Accept any public key unless configured to fail or raise."""
        if self.fail_auth:
            return paramiko.AUTH_FAILED
        if self.ssh_exception:
            raise paramiko.SSHException()
        return paramiko.AUTH_SUCCESSFUL

    def get_allowed_auths(self, username):
        return 'password,publickey'

    def check_channel_shell_request(self, channel):
        """Interactive shells are not supported."""
        return False

    def check_channel_pty_request(self, channel, term, width, height, pixelwidth,
                                  pixelheight, modes):
        return True

    def check_channel_direct_tcpip_request(self, chanid, origin, destination):
        """Start a :class:`Tunneler` forwarding channel ``chanid`` to ``destination``.

        Returns ``OPEN_FAILED_CONNECT_FAILED`` if the tunnel cannot be created.
        """
        logger.debug("Proxy connection %s -> %s requested", origin, destination,)
        extra = {'username' : self.transport.get_username()}
        logger.debug("Starting proxy connection %s -> %s",
                     origin, destination, extra=extra)
        try:
            tunnel = Tunneler(destination, self.transport, chanid)
            tunnel.start()
        except Exception as ex:
            logger.error("Error creating proxy connection to %s - %s",
                         destination, ex,)
            return paramiko.OPEN_FAILED_CONNECT_FAILED
        self.event.set()
        gevent.sleep()
        logger.debug("Proxy connection started")
        return paramiko.OPEN_SUCCEEDED

    def check_channel_forward_agent_request(self, channel):
        logger.debug("Forward agent key request for channel %s", channel)
        gevent.sleep()
        return True

    def check_channel_exec_request(self, channel, cmd):
        """Run ``cmd`` through the shell and stream its output to ``channel``.

        The child environment is this process's environment plus
        ``PYTHONIOENCODING`` and any variables received via env requests.
        """
        logger.debug("Got exec request on channel %s for cmd %s", channel, cmd)
        self.event.set()
        # Work on a copy: assigning into os.environ itself would push
        # PYTHONIOENCODING and every client-supplied variable into this
        # server process and into all later exec requests.
        _env = os.environ.copy()
        _env['PYTHONIOENCODING'] = self.encoding
        if hasattr(channel, 'environment'):
            _env.update(channel.environment)
        process = gevent.subprocess.Popen(cmd, stdout=gevent.subprocess.PIPE,
                                          stdin=gevent.subprocess.PIPE,
                                          stderr=gevent.subprocess.PIPE,
                                          shell=True, env=_env)
        gevent.spawn(self._read_response, channel, process)
        gevent.sleep(0)
        return True

    def check_channel_env_request(self, channel, name, value):
        """Record an env variable on the channel for later exec requests."""
        if not hasattr(channel, 'environment'):
            channel.environment = {}
        channel.environment.update({
            name.decode(self.encoding): value.decode(self.encoding)})
        return True

    def _read_response(self, channel, process):
        """Relay a finished command's stdout/stderr and exit status, then close."""
        gevent.sleep(0)
        logger.debug("Waiting for output")
        # NOTE(review): stdout is drained completely before stderr is read; a
        # command that fills the stderr pipe buffer first can block here.
        # stdin is also left open until communicate(), so commands that read
        # stdin will wait. Consider reading both pipes concurrently.
        for line in process.stdout:
            channel.send(line)
        for line in process.stderr:
            channel.send_stderr(line)
        process.communicate()
        channel.send_exit_status(process.returncode)
        logger.debug("Command finished with return code %s", process.returncode)
        # Let clients consume output from channel before closing
        gevent.sleep(.1)
        channel.close()
        gevent.sleep(0)
예제 #58
0
class MDSThrasher(Greenlet):
    """
    MDSThrasher::

    The MDSThrasher thrashes MDSs during execution of other tasks (workunits, etc).

    The config is optional.  Many of the config parameters are a maximum value
    to use when selecting a random value from a range.  To always use the maximum
    value, set randomize to false.  The config is a dict containing some or all of:

    max_thrash: [default: 1] the maximum number of active MDSs per FS that will be thrashed at
      any given time.

    max_thrash_delay: [default: 120] maximum number of seconds to delay before
      thrashing again.  The older key thrash_delay is still honoured when
      max_thrash_delay is not given.

    max_replay_thrash_delay: [default: 4] maximum number of seconds to delay while in
      the replay state before thrashing.  (Parsed, but not currently used by
      do_thrash.)

    max_revive_delay: [default: 10] maximum number of seconds to delay before
      bringing back a thrashed MDS.

    randomize: [default: true] enables randomization and use the max/min values

    seed: [no default] seed the random number generator

    thrash_in_replay: [default: 0.0] likelihood that the MDS will be thrashed
      during replay.  Value should be between 0.0 and 1.0.  (Validated, but not
      currently used by do_thrash.)

    thrash_max_mds: [default: 0.05] likelihood that the max_mds of the mds
      cluster will be modified to a value [1, current) or (current, starting
      max_mds]. When reduced, randomly selected MDSs other than rank 0 will be
      deactivated to reach the new max_mds.  Value should be between 0.0 and 1.0.

    thrash_while_stopping: [default: false] thrash an MDS while there
      are MDS in up:stopping (because max_mds was changed and some
      MDS were deactivated).

    thrash_weights: allows specific MDSs to be thrashed more/less frequently.
      This option overrides anything specified by max_thrash.  This option is a
      dict containing mds.x: weight pairs, for example {mds.a: 0.7, mds.b:
      0.3, mds.c: 0.0}; a list of single-entry mappings (as in the first
      example below) is accepted as well.  Each weight is a value from 0.0 to
      1.0.  Any MDSs not specified will be automatically given a weight of 0.0
      (not thrashed).
      For a given MDS, by default the thrasher delays for up to
      max_thrash_delay, thrashes, waits for the MDS to recover, and iterates.
      If a non-zero weight is specified for an MDS, for each iteration the
      thrasher chooses whether to thrash during that iteration based on a
      random value [0-1) not exceeding the weight of that MDS.

    Examples::


      The following example sets the likelihood that mds.a will be thrashed
      to 80%, mds.b to 20%, and other MDSs will not be thrashed.  It also sets the
      likelihood that an MDS will be thrashed in replay to 40%.
      Thrash weights do not have to sum to 1.

      tasks:
      - ceph:
      - mds_thrash:
          thrash_weights:
            - mds.a: 0.8
            - mds.b: 0.2
          thrash_in_replay: 0.4
      - ceph-fuse:
      - workunit:
          clients:
            all: [suites/fsx.sh]

      The following example disables randomization, and uses the max delay values:

      tasks:
      - ceph:
      - mds_thrash:
          max_thrash_delay: 10
          max_revive_delay: 1
          max_replay_thrash_delay: 4

    """
    def __init__(self, ctx, manager, config, fs, max_mds):
        Greenlet.__init__(self)

        self.config = config
        self.ctx = ctx
        # exception raised by do_thrash, if any (see _run)
        self.e = None
        self.logger = log.getChild('fs.[{f}]'.format(f=fs.name))
        self.fs = fs
        self.manager = manager
        # upper bound for max_mds when thrashing it
        self.max_mds = max_mds
        self.name = 'thrasher.fs.[{f}]'.format(f=fs.name)
        self.stopping = Event()

        self.randomize = bool(self.config.get('randomize', True))
        self.thrash_max_mds = float(self.config.get('thrash_max_mds', 0.05))
        self.max_thrash = int(self.config.get('max_thrash', 1))
        # 'max_thrash_delay' is the documented key; 'thrash_delay' is the key
        # that used to be read here and stays as a fallback for old configs.
        self.max_thrash_delay = float(self.config.get(
            'max_thrash_delay', self.config.get('thrash_delay', 120.0)))
        self.thrash_in_replay = float(
            self.config.get('thrash_in_replay', 0.0))
        assert self.thrash_in_replay >= 0.0 and self.thrash_in_replay <= 1.0, 'thrash_in_replay ({v}) must be between [0.0, 1.0]'.format(
            v=self.thrash_in_replay)
        self.max_replay_thrash_delay = float(
            self.config.get('max_replay_thrash_delay', 4.0))
        self.max_revive_delay = float(self.config.get('max_revive_delay',
                                                      10.0))

    def _run(self):
        try:
            self.do_thrash()
        except Exception as e:
            # Log exceptions here so we get the full backtrace (gevent loses them).
            # Also allow successful completion as gevent exception handling is a broken mess:
            #
            # 2017-02-03T14:34:01.259 CRITICAL:root:  File "gevent.libev.corecext.pyx", line 367, in gevent.libev.corecext.loop.handle_error (src/gevent/libev/gevent.corecext.c:5051)
            #   File "/home/teuthworker/src/git.ceph.com_git_teuthology_master/virtualenv/local/lib/python2.7/site-packages/gevent/hub.py", line 558, in handle_error
            #     self.print_exception(context, type, value, tb)
            #   File "/home/teuthworker/src/git.ceph.com_git_teuthology_master/virtualenv/local/lib/python2.7/site-packages/gevent/hub.py", line 605, in print_exception
            #     traceback.print_exception(type, value, tb, file=errstream)
            #   File "/usr/lib/python2.7/traceback.py", line 124, in print_exception
            #     _print(file, 'Traceback (most recent call last):')
            #   File "/usr/lib/python2.7/traceback.py", line 13, in _print
            #     file.write(str+terminator)
            # 2017-02-03T14:34:01.261 CRITICAL:root:IOError
            self.e = e
            self.logger.exception("exception:")
            # allow successful completion so gevent doesn't see an exception...

    def log(self, x):
        """Write data to logger assigned to this MDSThrasher"""
        self.logger.info(x)

    def stop(self):
        """Ask do_thrash to finish; it checks the flag between iterations."""
        self.stopping.set()

    def _random_delay(self, max_delay):
        """Return max_delay, or a uniform value in [0, max_delay] when randomizing."""
        if self.randomize:
            # uniform() accepts float bounds and an empty range (max_delay == 0);
            # randrange() rejects non-integer floats and raises on an empty range.
            return random.uniform(0.0, max_delay)
        return max_delay

    def _thrash_weight(self, label):
        """Return the configured thrash weight of ``label`` as a float (0.0 if absent)."""
        weights = self.config['thrash_weights']
        if isinstance(weights, list):
            # YAML list form: "- mds.a: 0.8" yields a list of one-entry dicts.
            merged = {}
            for entry in weights:
                merged.update(entry)
            weights = merged
        # float(): YAML may deliver strings, and the old '0.0' string default
        # could not be compared with the random draw.
        return float(weights.get(label, 0.0))

    def kill_mds(self, mds):
        """Stop the daemon of ``mds``, or power off its host when powercycling."""
        if self.config.get('powercycle'):
            (remote, ) = (self.ctx.cluster.only(
                'mds.{m}'.format(m=mds)).remotes.keys())
            self.log('kill_mds on mds.{m} doing powercycle of {s}'.format(
                m=mds, s=remote.name))
            self._assert_ipmi(remote)
            remote.console.power_off()
        else:
            self.ctx.daemons.get_daemon('mds', mds).stop()

    @staticmethod
    def _assert_ipmi(remote):
        assert remote.console.has_ipmi_credentials, (
            "powercycling requested but RemoteConsole is not "
            "initialized.  Check ipmi config.")

    def revive_mds(self, mds, standby_for_rank=None):
        """
        Revive mds -- do an ipmi powercycle (if indicated by the config)
        and then restart (using --hot-standby if specified).
        """
        if self.config.get('powercycle'):
            (remote, ) = (self.ctx.cluster.only(
                'mds.{m}'.format(m=mds)).remotes.keys())
            self.log('revive_mds on mds.{m} doing powercycle of {s}'.format(
                m=mds, s=remote.name))
            self._assert_ipmi(remote)
            remote.console.power_on()
            self.manager.make_admin_daemon_dir(self.ctx, remote)
        args = []
        # `is not None` so that rank 0 is passed through as well.
        if standby_for_rank is not None:
            args.extend(['--hot-standby', standby_for_rank])
        self.ctx.daemons.get_daemon('mds', mds).restart(*args)

    def wait_for_stable(self, rank=None, gid=None):
        """Poll the FS status every 2 seconds until the MDS cluster is stable.

        Without ``rank``: returns ``(status, None)`` once at least ``max_mds``
        ranks are up:active and not laggy.
        With ``rank``: returns ``status`` once ``rank`` is held by an
        up:active MDS whose gid differs from ``gid``, or once the number of
        actives already reaches ``max_mds`` so no replacement can happen.
        While any MDS is up:stopping and thrash_while_stopping is false, the
        cluster counts as unstable.

        :raises RuntimeError: after roughly 5 minutes of polling.
        """
        self.log('waiting for mds cluster to stabilize...')
        for itercount in itertools.count():
            status = self.fs.status()
            max_mds = status.get_fsmap(self.fs.id)['mdsmap']['max_mds']
            ranks = list(status.get_ranks(self.fs.id))
            # real lists: len() is applied below (filter objects have no len on Python 3)
            stopping = [info for info in ranks
                        if "up:stopping" == info['state']]
            actives = [info for info in ranks
                       if "up:active" == info['state'] and
                       "laggy_since" not in info]

            if not bool(self.config.get('thrash_while_stopping',
                                        False)) and len(stopping) > 0:
                if itercount % 5 == 0:
                    self.log(
                        'cluster is considered unstable while MDS are in up:stopping (!thrash_while_stopping)'
                    )
            else:
                if rank is not None:
                    try:
                        info = status.get_rank(self.fs.id, rank)
                        if info['gid'] != gid and "up:active" == info['state']:
                            self.log(
                                'mds.{name} has gained rank={rank}, replacing gid={gid}'
                                .format(name=info['name'], rank=rank, gid=gid))
                            return status
                    except Exception:
                        # presumably raised when the rank is not in the map;
                        # Exception (not a bare except) lets GreenletExit and
                        # KeyboardInterrupt still stop the thrasher.
                        pass  # no rank present
                    if len(actives) >= max_mds:
                        # no replacement can occur!
                        self.log(
                            "cluster has {count} actives (max_mds is {max_mds}), no MDS can replace rank {rank}"
                            .format(count=len(actives), max_mds=max_mds,
                                    rank=rank))
                        return status
                else:
                    if len(actives) >= max_mds:
                        self.log(
                            'mds cluster has {count} alive and active, now stable!'
                            .format(count=len(actives)))
                        return status, None
            if itercount > 300 / 2:  # 5 minutes
                raise RuntimeError('timeout waiting for cluster to stabilize')
            elif itercount % 5 == 0:
                self.log('mds map: {status}'.format(status=status))
            else:
                self.log('no change')
            sleep(2)

    def do_thrash(self):
        """
        Perform the random thrashing action until stop() is called.

        Each iteration waits a (random) delay; then, with probability
        thrash_max_mds, changes max_mds and deactivates the ranks above it;
        then kills up to max_thrash MDSs (or those selected by thrash_weights),
        waits for a replacement, revives the killed daemon and waits for it to
        report up:standby, up:standby-replay or up:active.  Counters of the
        actions taken are logged on exit.
        """

        self.log('starting mds_do_thrash for fs {fs}'.format(fs=self.fs.name))
        stats = {
            "max_mds": 0,
            "deactivate": 0,
            "kill": 0,
        }

        while not self.stopping.is_set():
            delay = self._random_delay(self.max_thrash_delay)

            if delay > 0.0:
                self.log('waiting for {delay} secs before thrashing'.format(
                    delay=delay))
                self.stopping.wait(delay)
                if self.stopping.is_set():
                    continue

            status = self.fs.status()

            if random.random() <= self.thrash_max_mds:
                max_mds = status.get_fsmap(self.fs.id)['mdsmap']['max_mds']
                # any value in [1, max_mds) or (max_mds, self.max_mds];
                # explicit lists so the concatenation also works on Python 3
                options = (list(range(1, max_mds)) +
                           list(range(max_mds + 1, self.max_mds + 1)))
                if len(options) > 0:
                    sample = random.sample(options, 1)
                    new_max_mds = sample[0]
                    self.log('thrashing max_mds: %d -> %d' %
                             (max_mds, new_max_mds))
                    self.fs.set_max_mds(new_max_mds)
                    stats['max_mds'] += 1

                    targets = [r for r in status.get_ranks(self.fs.id)
                               if r['rank'] >= new_max_mds]
                    if len(targets) > 0:
                        # deactivate mds in descending order
                        targets = sorted(targets,
                                         key=lambda r: r['rank'],
                                         reverse=True)
                        for target in targets:
                            self.log("deactivating rank %d" % target['rank'])
                            self.fs.deactivate(target['rank'])
                            stats['deactivate'] += 1
                            status = self.wait_for_stable()[0]
                    else:
                        status = self.wait_for_stable()[0]

            count = 0
            for info in status.get_ranks(self.fs.id):
                name = info['name']
                label = 'mds.' + name
                rank = info['rank']
                gid = info['gid']

                # if thrash_weights isn't specified and we've reached max_thrash,
                # we're done
                count = count + 1
                if 'thrash_weights' not in self.config and count > self.max_thrash:
                    break

                weight = 1.0
                if 'thrash_weights' in self.config:
                    weight = self._thrash_weight(label)
                # random.random() is uniform in [0.0, 1.0); the previous
                # randrange(0.0, 1.0) always returned 0, so any non-zero
                # weight meant "always thrash".
                skip = random.random()
                if weight <= skip:
                    self.log(
                        'skipping thrash iteration with skip ({skip}) > weight ({weight})'
                        .format(skip=skip, weight=weight))
                    continue

                self.log('kill {label} (rank={rank})'.format(label=label,
                                                             rank=rank))
                self.kill_mds(name)
                stats['kill'] += 1

                # wait for mon to report killed mds as crashed
                last_laggy_since = None
                itercount = 0
                while True:
                    status = self.fs.status()
                    info = status.get_mds(name)
                    if not info:
                        break
                    if 'laggy_since' in info:
                        last_laggy_since = info['laggy_since']
                        break
                    # NOTE(review): if mdsmap['failed'] lists ranks rather than
                    # names, this comparison never matches — confirm.
                    if any([(f == name) for f in status.get_fsmap(self.fs.id)
                            ['mdsmap']['failed']]):
                        break
                    self.log(
                        'waiting till mds map indicates {label} is laggy/crashed, in failed state, or {label} is removed from mdsmap'
                        .format(label=label))
                    itercount = itercount + 1
                    if itercount > 10:
                        self.log('mds map: {status}'.format(status=status))
                    sleep(2)

                if last_laggy_since:
                    self.log(
                        '{label} reported laggy/crashed since: {since}'.format(
                            label=label, since=last_laggy_since))
                else:
                    self.log('{label} down, removed from mdsmap'.format(
                        label=label))

                # wait for a standby mds to takeover and become active
                status = self.wait_for_stable(rank, gid)

                # wait for a while before restarting old active to become new
                # standby
                delay = self._random_delay(self.max_revive_delay)

                self.log(
                    'waiting for {delay} secs before reviving {label}'.format(
                        delay=delay, label=label))
                sleep(delay)

                self.log('reviving {label}'.format(label=label))
                self.revive_mds(name)

                for itercount in itertools.count():
                    if itercount > 300 / 2:  # 5 minutes
                        raise RuntimeError('timeout waiting for MDS to revive')
                    status = self.fs.status()
                    info = status.get_mds(name)
                    if info and info['state'] in ('up:standby',
                                                  'up:standby-replay',
                                                  'up:active'):
                        self.log('{label} reported in {state} state'.format(
                            label=label, state=info['state']))
                        break
                    self.log(
                        'waiting till mds map indicates {label} is in active, standby or standby-replay'
                        .format(label=label))
                    sleep(2)

        for stat in stats:
            self.log("stat['{key}'] = {value}".format(key=stat,
                                                      value=stats[stat]))
예제 #59
0
class MonitorNetwork(Jobmanager, NodeMonitorMixin):
    """ Job manager that tracks a coin network through its coinservers.

    Fetches block templates with ``getblocktemplate``, turns them into mining
    jobs (optionally embedding merged-mining data from the aux monitors named
    in ``merged``), announces them through the ``new_job`` event and submits
    solved blocks to every live coinserver.
    """
    one_min_stats = ['work_restarts', 'new_jobs', 'work_pushes']
    # Default configuration. 'block_poll' and 'job_refresh' are the @loop
    # intervals of _poll_height and _check_new_jobs respectively; REQUIRED
    # presumably marks keys the user config must supply (TODO confirm).
    defaults = config = dict(coinservs=REQUIRED,
                             diff1=0x0000FFFF00000000000000000000000000000000000000000000000000000000,
                             hashes_per_share=0xFFFF,
                             merged=tuple(),
                             block_poll=0.2,
                             job_refresh=15,
                             rpc_ping_int=2,
                             pow_block_hash=False,
                             poll=None,
                             currency=REQUIRED,
                             algo=REQUIRED,
                             pool_address='',
                             coinbase_string="",
                             signal=None,
                             payout_drk_mn=True,
                             max_blockheight=None)

    def __init__(self, config):
        """ Validate the configuration and set up job / network bookkeeping.

        :param config: user configuration merged over ``defaults`` by
            ``_configure``.
        :raises ConfigurationError: if ``pool_address`` is not recognised by
            ``get_bcaddress_version``.
        """
        NodeMonitorMixin.__init__(self)
        self._configure(config)
        if get_bcaddress_version(self.config['pool_address']) is None:
            raise ConfigurationError("No valid pool address configured! Exiting.")

        # Since some MonitorNetwork objs are polling and some aren't....
        self.gl_methods = ['_monitor_nodes', '_check_new_jobs']

        # Aux network monitors (merged mining)
        self.auxmons = []

        # internal vars
        self._last_gbt = {}
        self._job_counter = 0  # a unique job ID counter

        # Currently active jobs keyed by their unique ID
        self.jobs = {}
        self.stale_jobs = deque([], maxlen=10)
        self.latest_job = None  # The last job that was generated
        self.new_job = Event()
        self.last_signal = 0.0

        # general current network stats
        self.current_net = dict(difficulty=None,
                                height=None,
                                last_block=0.0,
                                prev_hash=None,
                                transactions=None,
                                subsidy=None)
        # last_solve_hash is initialised here too so the ``status`` output
        # has the same keys before and after the first solved block.
        self.block_stats = dict(accepts=0,
                                rejects=0,
                                solves=0,
                                last_solve_hash=None,
                                last_solve_height=None,
                                last_solve_time=None,
                                last_solve_worker=None)
        self.recent_blocks = deque(maxlen=15)

        # Run the looping height poller if we aren't getting push notifications.
        # poll=True forces polling, poll=False disables it, and poll=None
        # (default) polls only when no push signal is configured.
        if (not self.config['signal'] and self.config['poll'] is None) or self.config['poll']:
            self.gl_methods.append('_poll_height')

    @property
    def status(self):
        """ For display in the http monitor """
        ret = dict(net_state=self.current_net,
                   block_stats=self.block_stats,
                   last_signal=self.last_signal,
                   currency=self.config['currency'],
                   live_coinservers=len(self._live_connections),
                   down_coinservers=len(self._down_connections),
                   coinservers={},
                   job_count=len(self.jobs))
        for connection in self._live_connections:
            st = connection.status()
            st['status'] = 'live'
            ret['coinservers'][connection.name] = st
        for connection in self._down_connections:
            st = connection.status()
            st['status'] = 'down'
            ret['coinservers'][connection.name] = st
        return ret

    def start(self):
        """ Start the manager and hook up new-work sources.

        Registers the push-block signal handler when ``signal`` is
        configured, and subscribes to new work from the merged-mining
        monitors listed in ``merged``. Any listed monitor that cannot be
        found is logged as an error.
        """
        Jobmanager.start(self)

        if self.config['signal']:
            self.logger.info("Listening for push block notifs on signal {}"
                             .format(self.config['signal']))
            gevent.signal(self.config['signal'], self.getblocktemplate, signal=True)

        # Find desired auxmonitors
        self.config['merged'] = set(self.config['merged'])
        found_merged = set()

        for mon in self.manager.component_types['Jobmanager']:
            if mon.key in self.config['merged']:
                self.auxmons.append(mon)
                found_merged.add(mon.key)
                # Rebuild our job whenever the aux chain announces new work
                mon.new_job.rawlink(self.new_merged_work)

        for monitor in self.config['merged'] - found_merged:
            self.logger.error("Unable to locate Auxmonitor(s) '{}'".format(monitor))

    def found_block(self, raw_coinbase, address, worker, hash_hex, header, job, start):
        """ Submit a valid block (hopefully!) to the RPC servers

        The serialized block is submitted to every live coinserver
        concurrently, with one greenlet per server and up to 5 attempts each.
        If there are no live connections, it retries up to 200 times, 0.1s
        apart, until one exists.

        :param raw_coinbase: serialized coinbase passed to ``job.submit_serial``
        :param address: address credited with the solve (stats / result)
        :param worker: worker name credited with the solve
        :param hash_hex: hex block hash, used for logging and the result
        :param header: block header passed to ``job.submit_serial``
        :param job: the job (block template) the solution was found on
        :param start: ``time.time()`` value when the solve was seen; when
            truthy the submission latency is logged and reported
        :returns: dict describing the outcome (``success``, ``height``, ...),
            or an empty dict if no live connection ever became available.
        """
        block = hexlify(job.submit_serial(header, raw_coinbase=raw_coinbase))
        result = {}

        def record_outcome(success):
            # If we've already recorded a result, then return
            if result:
                return

            if start:
                submission_time = time.time() - start
                self.logger.info(
                    "Recording block submission outcome {} after {}"
                    .format(success, submission_time))
                if success:
                    self.manager.log_event(
                        "{name}.block_submission_{curr}:{t}|ms"
                        .format(name=self.manager.config['procname'],
                                curr=self.config['currency'],
                                t=submission_time * 1000))

            if success:
                self.block_stats['accepts'] += 1
                self.recent_blocks.append(
                    dict(height=job.block_height, timestamp=int(time.time())))
            else:
                self.block_stats['rejects'] += 1
                self.logger.info("{} BLOCK {}:{} REJECTED"
                                 .format(self.config['currency'], hash_hex,
                                         job.block_height))

            result.update(dict(
                address=address,
                height=job.block_height,
                total_subsidy=job.total_value,
                fees=job.fee_total,
                hex_bits=hexlify(job.bits),
                hex_hash=hash_hex,
                worker=worker,
                algo=job.algo,
                merged=False,
                success=success,
                currency=self.config['currency']
            ))

        def submit_block(conn):
            # Try ``submitblock`` first and fall back to the legacy
            # ``getblocktemplate`` submit mode; a ``None`` reply means accepted.
            retries = 0
            while retries < 5:
                retries += 1
                res = "failed"
                try:
                    res = conn.submitblock(block)
                except (CoinRPCException, socket.error, ValueError) as e:
                    self.logger.info("Block failed to submit to the server {} with submitblock! {}"
                                     .format(conn.name, e))
                    # Only RPC exceptions carry an ``error`` payload; socket
                    # and value errors don't, so never assume the attribute
                    # exists (a bare getattr here used to raise
                    # AttributeError and abort this server's submission).
                    error = getattr(e, 'error', None)
                    code = error.get('code', 0) if isinstance(error, dict) else 0
                    # RPC error code -8 is not logged at error level; either
                    # way the getblocktemplate submit mode is tried next.
                    if code != -8:
                        self.logger.error(getattr(e, 'error', e), exc_info=True)
                    try:
                        res = conn.getblocktemplate({'mode': 'submit', 'data': block})
                    except (CoinRPCException, socket.error, ValueError) as e:
                        self.logger.error("Block failed to submit to the server {}!"
                                          .format(conn.name), exc_info=True)
                        self.logger.error(getattr(e, 'error', e))

                if res is None:
                    self.logger.info("{} BLOCK {}:{} accepted by {}"
                                     .format(self.config['currency'], hash_hex,
                                             job.block_height, conn.name))
                    record_outcome(True)
                    break  # break retry loop if success
                else:
                    self.logger.error(
                        "Block failed to submit to the server {}, "
                        "server returned {}!".format(conn.name, res),
                        exc_info=True)
                sleep(1)
                self.logger.info("Retry {} for connection {}".format(retries, conn.name))

        for tries in xrange(200):
            if not self._live_connections:
                self.logger.error("No live connections to submit new block to!"
                                  " Retry {} / 200.".format(tries))
                sleep(0.1)
                continue

            gl = []
            for conn in self._live_connections:
                # spawn a new greenlet for each submission to do them all async.
                # lower orphan chance
                gl.append(spawn(submit_block, conn))

            gevent.joinall(gl)
            # If none of the submission threads were successfull then record a
            # failure
            if not result:
                record_outcome(False)
            break

        self.logger.log(35, "Valid network block identified!")
        self.logger.info("New block at height {} with hash {} and subsidy {}"
                         .format(job.block_height,
                                 hash_hex,
                                 job.total_value))

        self.block_stats['solves'] += 1
        self.block_stats['last_solve_hash'] = hash_hex
        self.block_stats['last_solve_height'] = job.block_height
        self.block_stats['last_solve_worker'] = "{}.{}".format(address, worker)
        self.block_stats['last_solve_time'] = datetime.datetime.utcnow()

        if __debug__:
            self.logger.debug("New block hex dump:\n{}".format(block))
            self.logger.debug("Coinbase: {}".format(str(job.coinbase.to_dict())))
            for trans in job.transactions:
                self.logger.debug(str(trans.to_dict()))

        # Pass back all the results to the reporter who's waiting
        return result

    @loop(interval='block_poll')
    def _poll_height(self):
        """ Poll ``getblockcount`` and fetch a new template when the chain
        height changes. Only scheduled when polling is enabled (see
        ``__init__``). RPC failures are ignored until the next poll. """
        try:
            height = self.call_rpc('getblockcount')
        except RPCException:
            return

        if self.current_net['height'] != height:
            self.logger.info("New block on main network detected with polling")
            self.current_net['height'] = height
            self.getblocktemplate(new_block=True)

    @loop(interval='job_refresh')
    def _check_new_jobs(self):
        """ Periodically refresh the template so new transactions get mined. """
        self.getblocktemplate()

    def getblocktemplate(self, new_block=False, signal=False):
        """ Fetch a fresh block template and generate a job if it changed.

        :param new_block: treat the template as a new network block (flush and
            push). Also set automatically when the template height differs
            from the previous template's.
        :param signal: True when invoked by the push-notification signal
            handler; a signal that doesn't reveal a new block is ignored.
        :returns: False if the RPC call failed, otherwise None.
        """
        if signal:
            self.last_signal = time.time()
        try:
            # request local memory pool and load it in
            bt = self.call_rpc('getblocktemplate',
                               {'capabilities': [
                                   'coinbasevalue',
                                   'coinbase/append',
                                   'coinbase',
                                   'generation',
                                   'time',
                                   'transactions/remove',
                                   'prevblock',
                               ]})
        except RPCException:
            return False

        if self._last_gbt.get('height') != bt['height']:
            new_block = True
        # If this was from a push signal and the template is for a new block,
        # note it; otherwise the signal told us nothing new, so bail out.
        if signal and new_block:
            self.logger.info("Push block signal notified us of a new block!")
        elif signal:
            self.logger.info("Push block signal notified us of a block we "
                             "already know about!")
            return

        # generate a new job if we got some new work! ``update_time`` is our
        # own bookkeeping key (added below) and never appears in a fresh
        # template, so it must be excluded from the comparison; otherwise
        # every template would look new.
        previous = dict(self._last_gbt)
        previous.pop('update_time', None)
        dirty = bt != previous
        if dirty:
            self._last_gbt = bt
            self._last_gbt['update_time'] = time.time()

        if new_block or dirty:
            # generate a new job and push it if there's a new block on the
            # network
            self.generate_job(push=new_block, flush=new_block, new_block=new_block)

    def new_merged_work(self, event):
        """ ``rawlink`` callback fired when an aux monitor has new work;
        pushes a job built with the updated merged-mining data. """
        self.generate_job(push=True, flush=event.flush, network='aux')

    def generate_job(self, push=False, flush=False, new_block=False, network='main'):
        """ Creates a new job for miners to work on. Push will trigger an
        event that sends new work but doesn't force a restart. If flush is
        true a job restart will be triggered.

        :param push: mark the job as type 1 (push) unless ``flush`` is set
        :param flush: mark the job as type 0 (restart) and drop all
            previously tracked jobs
        :param new_block: refresh ``current_net`` from this template and emit
            per-block metrics
        :param network: label ('main' or 'aux') used only in log messages
        """

        # aux monitors will often call this early when not needed at startup
        if not self._last_gbt:
            self.logger.warning("Cannot generate new job, missing last GBT info")
            return

        if self.auxmons:
            merged_work = {}
            auxdata = {}
            for auxmon in self.auxmons:
                if auxmon.last_work['hash'] is None:
                    continue
                # NOTE(review): 'target' is filled from last_work['type'] here
                # but from last_work['target'] below; confirm this is intended.
                merged_work[auxmon.last_work['chainid']] = dict(
                    hash=auxmon.last_work['hash'],
                    target=auxmon.last_work['type']
                )

            tree, size = bitcoin_data.make_auxpow_tree(merged_work)
            mm_hashes = [merged_work.get(tree.get(i), dict(hash=0))['hash']
                         for i in xrange(size)]
            # Merged-mining marker followed by the aux merkle root commitment
            mm_data = '\xfa\xbemm'
            mm_data += bitcoin_data.aux_pow_coinbase_type.pack(dict(
                merkle_root=bitcoin_data.merkle_hash(mm_hashes),
                size=size,
                nonce=0,
            ))

            for auxmon in self.auxmons:
                if auxmon.last_work['hash'] is None:
                    continue
                data = dict(target=auxmon.last_work['target'],
                            hash=auxmon.last_work['hash'],
                            height=auxmon.last_work['height'],
                            found_block=auxmon.found_block,
                            index=mm_hashes.index(auxmon.last_work['hash']),
                            type=auxmon.last_work['type'],
                            hashes=mm_hashes)
                auxdata[auxmon.config['currency']] = data
        else:
            auxdata = {}
            mm_data = None

        # here we recalculate the current merkle branch and partial
        # coinbases for passing to the mining clients
        coinbase = Transaction()
        coinbase.version = 2
        # create a coinbase input with encoded height and padding for the
        # extranonces so script length is accurate
        extranonce_length = (self.manager.config['extranonce_size'] +
                             self.manager.config['extranonce_serv_size'])
        coinbase.inputs.append(
            Input.coinbase(self._last_gbt['height'],
                           addtl_push=[mm_data] if mm_data else [],
                           extra_script_sig=b'\0' * extranonce_length,
                           desc_string=self.config['coinbase_string']))

        coinbase_value = self._last_gbt['coinbasevalue']

        # Payout Darkcoin masternodes
        mn_enforcement = self._last_gbt.get('enforce_masternode_payments', True)
        if (self.config['payout_drk_mn'] is True or mn_enforcement is True) \
                and self._last_gbt.get('payee', '') != '':
            # Grab the darkcoin payout amount, default to 20%
            payout = self._last_gbt.get('payee_amount', coinbase_value / 5)
            coinbase_value -= payout
            coinbase.outputs.append(
                Output.to_address(payout, self._last_gbt['payee']))
            self.logger.debug(
                "Created TX output for masternode at ({}:{}). Coinbase value "
                "reduced to {}".format(self._last_gbt['payee'], payout,
                                       coinbase_value))

        # simple output to the proper address and value
        coinbase.outputs.append(
            Output.to_address(coinbase_value, self.config['pool_address']))

        job_id = hexlify(struct.pack(str("I"), self._job_counter))
        bt_obj = BlockTemplate.from_gbt(self._last_gbt,
                                        coinbase,
                                        extranonce_length,
                                        [Transaction(unhexlify(t['data']), fees=t['fee'])
                                         for t in self._last_gbt['transactions']])
        # add in our merged mining data
        if mm_data:
            hashes = [bitcoin_data.hash256(tx.raw) for tx in bt_obj.transactions]
            bt_obj.merkle_link = bitcoin_data.calculate_merkle_link([None] + hashes, 0)
        bt_obj.merged_data = auxdata
        bt_obj.job_id = job_id
        bt_obj.diff1 = self.config['diff1']
        bt_obj.algo = self.config['algo']
        bt_obj.currency = self.config['currency']
        bt_obj.pow_block_hash = self.config['pow_block_hash']
        bt_obj.block_height = self._last_gbt['height']
        bt_obj.acc_shares = set()
        # Job type: 0 = flush/restart, 1 = push, 2 = regular refresh
        if flush:
            bt_obj.type = 0
        elif push:
            bt_obj.type = 1
        else:
            bt_obj.type = 2
        bt_obj.found_block = self.found_block

        # Push the fresh job to users after updating details
        self._job_counter += 1
        if flush:
            self.jobs.clear()
        self.jobs[job_id] = bt_obj
        self.latest_job = bt_obj

        # set()+clear() wakes current waiters without leaving the event set
        self.new_job.job = bt_obj
        self.new_job.set()
        self.new_job.clear()
        event = ("{name}.jobmanager.new_job:1|c\n"
                 .format(name=self.manager.config['procname']))
        if push or flush:
            self.logger.info(
                "{}: New block template with {:,} trans. "
                "Diff {:,.4f}. Subsidy {:,.2f}. Height {:,}. Merged: {}"
                .format("FLUSH" if flush else "PUSH",
                        len(self._last_gbt['transactions']),
                        bits_to_difficulty(self._last_gbt['bits']),
                        self._last_gbt['coinbasevalue'] / 100000000.0,
                        self._last_gbt['height'],
                        ', '.join(auxdata.keys())))
            event += ("{name}.jobmanager.work_push:1|c\n"
                      .format(name=self.manager.config['procname']))

        # Stats and notifications now that it's pushed
        if flush:
            event += ("{name}.jobmanager.work_restart:1|c\n"
                      .format(name=self.manager.config['procname']))
            self.logger.info("New {} network block announced! Wiping previous"
                             " jobs and pushing".format(network))
        elif push:
            self.logger.info("New {} network block announced, pushing new job!"
                             .format(network))

        if new_block:
            hex_bits = hexlify(bt_obj.bits)
            self.current_net['difficulty'] = bits_to_difficulty(hex_bits)
            self.current_net['subsidy'] = bt_obj.total_value
            # The template is for the *next* block, so the tip is height - 1
            self.current_net['height'] = bt_obj.block_height - 1
            self.current_net['last_block'] = time.time()
            self.current_net['prev_hash'] = bt_obj.hashprev_be_hex
            self.current_net['transactions'] = len(bt_obj.transactions)

            event += (
                "{name}.{curr}.difficulty:{diff}|g\n"
                "{name}.{curr}.subsidy:{subsidy}|g\n"
                "{name}.{curr}.job_generate:{t}|g\n"
                "{name}.{curr}.height:{height}|g"
                .format(name=self.manager.config['procname'],
                        curr=self.config['currency'],
                        diff=self.current_net['difficulty'],
                        subsidy=bt_obj.total_value,
                        height=bt_obj.block_height - 1,
                        t=(time.time() - self._last_gbt['update_time']) * 1000))
        self.manager.log_event(event)
예제 #60
0
class InboundESL(object):
    """Client for FreeSWITCH's inbound Event Socket (ESL) interface.

    ``connect()`` opens a TCP connection to ``host:port``, authenticates with
    ``password`` and runs two greenlets. One reads raw ESL messages from the
    socket and routes them. The other calls the handlers registered with
    ``register_handle`` for queued events. ``send()`` writes a command and
    blocks until its reply arrives.
    """

    def __init__(self, host, port, password):
        self.host = host
        self.port = port
        self.password = password
        self.timeout = 5  # seconds; only applied while connecting
        self._run = True
        self._EOL = '\n'
        # AsyncResults of commands awaiting a reply, oldest first
        self._commands_sent = []
        self._auth_request_event = Event()
        self._receive_events_greenlet = None
        self._process_events_greenlet = None
        self.event_handlers = {}
        self.connected = False

        self._esl_event_queue = Queue()
        self._process_esl_event_queue = True

    def connect(self):
        """Open the socket, start both greenlets and authenticate.

        :raises NotConnectedError: if the server closes the connection before
            asking for authentication.
        :raises ValueError: if the password is rejected.
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.settimeout(self.timeout)
        try:
            self.sock.connect((self.host, self.port))
        except Exception:
            # Don't leak the file descriptor when the connection fails.
            self.sock.close()
            raise
        self.connected = True
        self.sock.settimeout(None)
        self.sock_file = self.sock.makefile()
        self._receive_events_greenlet = gevent.spawn(self.receive_events)
        self._process_events_greenlet = gevent.spawn(self.process_events)
        # Set on auth/request or rude-rejection, and also when the receive
        # loop exits, so this cannot block forever on a dead connection.
        self._auth_request_event.wait()
        if not self.connected:
            raise NotConnectedError('Server closed connection, check FreeSWITCH config.')
        self.authenticate()

    def receive_events(self):
        """Reader greenlet: read ESL messages and dispatch them.

        Header lines are accumulated until an empty line. The accumulated
        block is then parsed into an ``ESLEvent`` and passed to
        ``handle_event``. The loop ends when ``_run`` is cleared, the read
        fails, or the peer closes the connection. Anything still waiting on
        this loop is then released.
        """
        buf = ''
        try:
            while self._run:
                try:
                    data = self.sock_file.readline()
                except Exception:
                    logging.debug("Error reading from socket.", exc_info=True)
                    self._run = False
                    self.connected = False
                    self.sock.close()
                    break
                if not data:
                    if self.connected:
                        logging.error("Error receiving data, is FreeSWITCH running?")
                        self.connected = False
                    break
                # Empty line terminates the header block
                if data == self._EOL:
                    event = ESLEvent(buf)
                    buf = ''
                    try:
                        self.handle_event(event)
                    except EOFError:
                        # Peer closed the socket in the middle of a body.
                        logging.error("Connection closed while reading event body.")
                        self.connected = False
                        break
                    continue
                buf += data
        finally:
            self._release_waiters()

    def _release_waiters(self):
        """Unblock everything that depends on the (now stopped) receive loop."""
        # connect() waits on this before checking self.connected.
        self._auth_request_event.set()
        # No reply will ever be read for commands still in flight; fail them
        # so send() raises instead of blocking forever.
        while self._commands_sent:
            self._commands_sent.pop(0).set_exception(NotConnectedError())

    @staticmethod
    def _read_socket(sock, length):
        """Receive data from socket until the length is reached.

        :param sock: file-like object (as returned by ``socket.makefile()``).
        :param length: number of bytes to read.
        :returns: exactly ``length`` bytes.
        :raises EOFError: if the peer closes the connection before ``length``
            bytes were received (an empty read means EOF; without this check
            the loop would spin forever).
        """
        data = sock.read(length)
        data_length = len(data)
        while data_length < length:
            logging.warning(
                'Socket should read %s bytes, but actually read %s bytes. '
                'Consider increasing "net.core.rmem_default".' %
                (length, data_length)
            )
            chunk = sock.read(length - data_length)
            if not chunk:
                raise EOFError(
                    'Connection closed after reading %s of %s bytes.' %
                    (data_length, length))
            data += chunk
            data_length = len(data)
        return data

    def handle_event(self, event):
        """Route one parsed ESL message according to its Content-Type.

        Command replies (``command/reply``, ``api/response``) resolve the
        oldest pending ``send()``; this assumes the server answers commands
        in the order they were sent. Other messages with a body are read in
        full and queued for ``process_events``.

        :raises EOFError: if the connection closes while a body is read.
        """
        content_type = event.headers['Content-Type']
        if content_type == 'auth/request':
            self._auth_request_event.set()
        elif content_type == 'command/reply':
            async_response = self._commands_sent.pop(0)
            event.data = event.headers['Reply-Text']
            async_response.set(event)
        elif content_type == 'api/response':
            length = int(event.headers['Content-Length'])
            event.data = self._read_socket(self.sock_file, length)
            async_response = self._commands_sent.pop(0)
            async_response.set(event)
        elif content_type == 'text/disconnect-notice':
            self.connected = False
        elif content_type == 'text/rude-rejection':
            self.connected = False
            length = int(event.headers['Content-Length'])
            # Drain the body so the stream stays in sync, then wake connect()
            self._read_socket(self.sock_file, length)
            self._auth_request_event.set()
        else:
            length = int(event.headers['Content-Length'])
            data = self._read_socket(self.sock_file, length)
            if content_type == 'log/data':
                event.data = data
            else:
                event.parse_data(data)
            self._esl_event_queue.put(event)

    def _safe_exec_handler(self, handler, event):
        """Call ``handler(event)``, logging (not propagating) its errors.

        Only ``Exception`` subclasses are caught, so greenlet kills and
        interpreter-exit signals still propagate.
        """
        try:
            handler(event)
        except Exception:
            logging.exception('ESL %s raised exception.' % handler.__name__)
            logging.error(pprint.pformat(event.headers))

    def process_events(self):
        """Processor greenlet: call registered handlers for queued events.

        Handlers are looked up by Event-Subclass for CUSTOM events and by
        Event-Name otherwise. ``log/data`` messages without a match fall
        back to the 'log' handlers. The optional ``before_handle`` and
        ``after_handle`` hooks run around the handlers.
        """
        logging.debug('Event Processor Running')
        while self._run:
            if not self._process_esl_event_queue:
                gevent.sleep(1)
                continue

            try:
                # Timeout so the loop re-checks self._run periodically
                event = self._esl_event_queue.get(timeout=1)
            except gevent.queue.Empty:
                continue

            if event.headers.get('Event-Name') == 'CUSTOM':
                handlers = self.event_handlers.get(event.headers.get('Event-Subclass'))
            else:
                handlers = self.event_handlers.get(event.headers.get('Event-Name'))

            if not handlers and event.headers.get('Content-Type') == 'log/data':
                handlers = self.event_handlers.get('log')

            if not handlers:
                continue

            if hasattr(self, 'before_handle'):
                self._safe_exec_handler(self.before_handle, event)

            for handle in handlers:
                self._safe_exec_handler(handle, event)

            if hasattr(self, 'after_handle'):
                self._safe_exec_handler(self.after_handle, event)

    def send(self, data):
        """Send one ESL command and block until its reply arrives.

        :returns: the reply event (``event.data`` holds the reply text/body).
        :raises NotConnectedError: if not connected, or if the connection is
            lost before the reply arrives.
        """
        if not self.connected:
            raise NotConnectedError()
        async_response = gevent.event.AsyncResult()
        self._commands_sent.append(async_response)
        raw_msg = (data + self._EOL*2).encode('utf-8')
        self.sock.send(raw_msg)
        response = async_response.get()
        return response

    def authenticate(self):
        """Send the password; raise ValueError unless the server accepts it."""
        response = self.send('auth %s' % self.password)
        if response.headers['Reply-Text'] != '+OK accepted':
            raise ValueError('Invalid password.')

    def register_handle(self, name, handler):
        """Register ``handler`` for event ``name`` (duplicates are ignored)."""
        handlers = self.event_handlers.setdefault(name, [])
        if handler not in handlers:
            handlers.append(handler)

    def unregister_handle(self, name, handler):
        """Remove ``handler`` for event ``name``.

        :raises ValueError: if nothing is registered for ``name`` or
            ``handler`` is not among its handlers.
        """
        if name not in self.event_handlers:
            raise ValueError('No handlers found for event: %s' % name)
        self.event_handlers[name].remove(handler)
        if not self.event_handlers[name]:
            del self.event_handlers[name]

    def stop(self):
        """End the session, stop both greenlets and close the socket.

        Safe to call when the connection was already lost or never opened.
        """
        if self.connected:
            try:
                self.send('exit')
            except (NotConnectedError, socket.error):
                # Connection dropped meanwhile; we're shutting down anyway.
                pass
        self._run = False
        logging.info("Waiting for receive greenlet exit")
        if self._receive_events_greenlet is not None:
            self._receive_events_greenlet.join()
        logging.info("Waiting for event processing greenlet exit")
        if self._process_events_greenlet is not None:
            self._process_events_greenlet.join()
        # Close even if ``connected`` is already False: after a disconnect
        # notice or EOF the receive loop marks the connection down without
        # closing it, so gating on ``connected`` would leak the socket.
        for closable in (getattr(self, 'sock', None),
                         getattr(self, 'sock_file', None)):
            if closable is not None:
                closable.close()