예제 #1
0
class ChatClient(Client):
    """Interactive line-based chat client.

    A forked task reads lines typed at the terminal and pushes them onto
    ``self.input``; ``chat`` multiplexes those lines with lines arriving from
    the server, printing the latter and sending the former.
    """

    def __init__(self, *args, **kw):
        Client.__init__(self, *args, **kw)
        # Lines typed by the local user, filled by ``input_handler``.
        self.input = Queue()

    def read_chat_message(self, prompt):
        """Show ``prompt`` on the terminal and return the raw line typed."""
        return raw_input(prompt)

    def input_handler(self):
        """Queue the nickname first, then every subsequent typed line."""
        # `thread` presumably runs the blocking terminal read off the event
        # loop; each line is stripped before being queued.
        self.nick = thread(self.read_chat_message, "nick: ").strip()
        self.input.put(self.nick)
        while True:
            line = thread(self.read_chat_message, "")
            self.input.put(line.strip())

    @call
    def chat(self):
        """Register the nickname, then relay between terminal and server."""
        fork(self.input_handler)
        send("%s\r\n" % self.input.get())
        while True:
            source, payload = first(until_eol=True, waits=[self.input])
            if source != "until_eol":
                # A line typed locally: forward it to the server.
                send("%s\r\n" % payload)
                continue
            print(payload.strip())
예제 #2
0
    def websocket_protocol(self, req):
        """Drive an upgraded WebSocket connection until it closes.

        A pair of queues (incoming, outgoing) is handed to the user-supplied
        ``web_socket_handler`` running in a forked loop, while this method
        pumps frames between the socket and those queues using the frame
        handler that matches the handshake flavour (RFC 6455 or not).

        """
        inq, outq = Queue(), Queue()

        def wrap(req, inq, outq):
            # When the user handler returns, queue a WebSocketDisconnect on
            # the outgoing queue for the frame loop to pick up.
            self.web_socket_handler(req, inq, outq)
            outq.put(WebSocketDisconnect())

        handler_loop = fork(wrap, req, inq, outq)

        handle_frames = (self.handle_rfc_6455_frames if req.rfc_handshake
                         else self.handle_non_rfc_frames)
        try:
            handle_frames(inq, outq)
        except ConnectionClosed:
            # Tell a handler that is still running that the peer is gone.
            if handler_loop.running:
                inq.put(WebSocketDisconnect())
            raise
예제 #3
0
class ChatClient(Client):
    """Interactive line-based chat client.

    A forked task reads lines typed at the terminal and pushes them onto
    ``self.input``; ``chat`` multiplexes those lines with lines arriving from
    the server, printing the latter and sending the former.
    """

    def __init__(self, *args, **kw):
        Client.__init__(self, *args, **kw)
        # Lines typed by the local user, filled by ``input_handler``.
        self.input = Queue()

    def read_chat_message(self, prompt):
        """Show ``prompt`` on the terminal and return the raw line typed."""
        return raw_input(prompt)

    def input_handler(self):
        """Queue the nickname first, then every subsequent typed line."""
        # `thread` presumably runs the blocking terminal read off the event
        # loop; each line is stripped before being queued.
        self.nick = thread(self.read_chat_message, "nick: ").strip()
        self.input.put(self.nick)
        while True:
            line = thread(self.read_chat_message, "")
            self.input.put(line.strip())

    @call
    def chat(self):
        """Register the nickname, then relay between terminal and server."""
        fork(self.input_handler)
        send("%s\r\n" % self.input.get())
        while True:
            source, payload = first(until_eol=True, waits=[self.input])
            if source != "until_eol":
                # A line typed locally: forward it to the server.
                send("%s\r\n" % payload)
                continue
            print(payload.strip())
예제 #4
0
파일: test_queue.py 프로젝트: 1angxi/diesel
class TestQueueTimeouts(object):
    # Scenario: a consumer with a 0.01s timeout, a producer that puts one
    # item after 0.05s, and a consumer with a 0.10s timeout. The first
    # consumer should time out and the second should receive the item.

    def setup(self):
        self.result = Event()
        self.queue = Queue()
        self.timeouts = 0
        for task, arg in ((self.consumer, 0.01),
                          (self.producer, 0.05),
                          (self.consumer, 0.10)):
            diesel.fork(task, arg)
        # Wait (at most TIMEOUT) for some consumer to signal success.
        source, _ = diesel.first(sleep=TIMEOUT, waits=[self.result])
        assert source != 'sleep', 'timed out'

    def consumer(self, timeout):
        # Count a timeout, or flag that an item was received.
        try:
            self.queue.get(timeout=timeout)
        except QueueTimeout:
            self.timeouts += 1
        else:
            self.result.set()

    def producer(self, delay):
        # Put a single item on the queue after `delay`.
        diesel.sleep(delay)
        self.queue.put('test')

    def test_a_consumer_timed_out(self):
        assert self.timeouts == 1

    def test_a_consumer_got_a_value(self):
        assert self.result.is_set
예제 #5
0
파일: test_queue.py 프로젝트: yadra/diesel
class TestQueueTimeouts(object):
    # Scenario: a consumer with a 0.01s timeout, a producer that puts one
    # item after 0.05s, and a consumer with a 0.10s timeout. The first
    # consumer should time out and the second should receive the item.

    def setup(self):
        self.result = Event()
        self.queue = Queue()
        self.timeouts = 0
        for task, arg in ((self.consumer, 0.01),
                          (self.producer, 0.05),
                          (self.consumer, 0.10)):
            diesel.fork(task, arg)
        # Wait (at most TIMEOUT) for some consumer to signal success.
        source, _ = diesel.first(sleep=TIMEOUT, waits=[self.result])
        assert source != 'sleep', 'timed out'

    def consumer(self, timeout):
        # Count a timeout, or flag that an item was received.
        try:
            self.queue.get(timeout=timeout)
        except QueueTimeout:
            self.timeouts += 1
        else:
            self.result.set()

    def producer(self, delay):
        # Put a single item on the queue after `delay`.
        diesel.sleep(delay)
        self.queue.put('test')

    def test_a_consumer_timed_out(self):
        assert self.timeouts == 1

    def test_a_consumer_got_a_value(self):
        assert self.result.is_set
예제 #6
0
class ThreadPool(object):
    """Run ``handler`` over items from ``generator`` with bounded concurrency.

    Calling the instance drives the pool: it keeps up to ``concurrency``
    worker loops forked, and whenever at least one worker is idle it pulls
    the next item from ``generator()`` and puts it on a shared queue. The
    pool stops when ``generator()`` raises StopIteration (or the loop exits
    through any other exception); every worker is then sent
    ``ThreadPoolDie``, and if a ``finalizer`` was given it is forked once
    all workers have exited.

    Args:
        concurrency: maximum number of worker loops.
        handler: called with each item produced by ``generator``.
        generator: zero-argument callable returning the next item and
            raising StopIteration when exhausted.
        finalizer: optional zero-argument callable forked after all
            workers have finished.
    """
    def __init__(self, concurrency, handler, generator, finalizer=None):
        self.concurrency = concurrency
        self.handler = handler
        self.generator = generator
        self.finalizer = finalizer

    def handler_wrap(self):
        """Worker loop: take items off the queue and run ``handler`` on them."""
        try:
            label("thread-pool-%s" % self.handler)
            while True:
                self.waiting += 1
                # The first idle worker wakes the producer loop in __call__.
                if self.waiting == 1:
                    self.trigger.set()
                i = self.q.get()
                self.waiting -= 1
                # Identity test against the sentinel: ``==`` would invoke the
                # work item's own __eq__, which may be arbitrary and, for
                # array-like items, does not even return a bool.
                if i is ThreadPoolDie:
                    return
                self.handler(i)
        finally:
            # Runs on normal exit and when handler raises; __call__ re-forks
            # replacements for workers that died.
            self.running -= 1
            if self.waiting == 0:
                self.trigger.set()
            if self.running == 0:
                self.finished.set()

    def __call__(self):
        self.q = Queue()
        self.trigger = Event()
        self.finished = Event()
        self.waiting = 0
        self.running = 0
        try:
            while True:
                # Top the pool back up to `concurrency` workers.
                for x in xrange(self.concurrency - self.running):
                    self.running += 1
                    fork(self.handler_wrap)

                # Don't pull the next item until some worker is idle.
                if self.waiting == 0:
                    self.trigger.wait()
                    self.trigger.clear()

                try:
                    n = self.generator()
                except StopIteration:
                    break

                self.q.put(n)
                sleep()
        finally:
            # One stop sentinel per possible worker.
            for x in xrange(self.concurrency):
                self.q.put(ThreadPoolDie)
            if self.finalizer:
                self.finished.wait()
                fork(self.finalizer)
예제 #7
0
파일: pool.py 프로젝트: HVF/diesel
class ThreadPool(object):
    """Run ``handler`` over items from ``generator`` with bounded concurrency.

    Calling the instance drives the pool: it keeps up to ``concurrency``
    worker loops forked, and whenever at least one worker is idle it pulls
    the next item from ``generator()`` and puts it on a shared queue. The
    pool stops when ``generator()`` raises StopIteration (or the loop exits
    through any other exception); every worker is then sent
    ``ThreadPoolDie``, and if a ``finalizer`` was given it is forked once
    all workers have exited.

    Args:
        concurrency: maximum number of worker loops.
        handler: called with each item produced by ``generator``.
        generator: zero-argument callable returning the next item and
            raising StopIteration when exhausted.
        finalizer: optional zero-argument callable forked after all
            workers have finished.
    """
    def __init__(self, concurrency, handler, generator, finalizer=None):
        self.concurrency = concurrency
        self.handler = handler
        self.generator = generator
        self.finalizer = finalizer

    def handler_wrap(self):
        """Worker loop: take items off the queue and run ``handler`` on them."""
        try:
            label("thread-pool-%s" % self.handler)
            while True:
                self.waiting += 1
                # The first idle worker wakes the producer loop in __call__.
                if self.waiting == 1:
                    self.trigger.set()
                i = self.q.get()
                self.waiting -= 1
                # Identity test against the sentinel: ``==`` would invoke the
                # work item's own __eq__, which may be arbitrary and, for
                # array-like items, does not even return a bool.
                if i is ThreadPoolDie:
                    return
                self.handler(i)
        finally:
            # Runs on normal exit and when handler raises; __call__ re-forks
            # replacements for workers that died.
            self.running -= 1
            if self.waiting == 0:
                self.trigger.set()
            if self.running == 0:
                self.finished.set()

    def __call__(self):
        self.q = Queue()
        self.trigger = Event()
        self.finished = Event()
        self.waiting = 0
        self.running = 0
        try:
            while True:
                # Top the pool back up to `concurrency` workers.
                for x in xrange(self.concurrency - self.running):
                    self.running += 1
                    fork(self.handler_wrap)

                # Don't pull the next item until some worker is idle.
                if self.waiting == 0:
                    self.trigger.wait()
                    self.trigger.clear()

                try:
                    n = self.generator()
                except StopIteration:
                    break

                self.q.put(n)
                sleep()
        finally:
            # One stop sentinel per possible worker.
            for x in xrange(self.concurrency):
                self.q.put(ThreadPoolDie)
            if self.finalizer:
                self.finished.wait()
                fork(self.finalizer)
예제 #8
0
파일: websockets.py 프로젝트: yadra/diesel
    def websocket_protocol(self, req):
        """Runs the WebSocket protocol after the handshake is complete.

        Creates two `Queue` instances for incoming and outgoing messages and
        passes them to the `web_socket_handler` that was supplied to the
        `WebSocketServer` constructor.

        For non-RFC (draft-76 style) clients, the final step of the
        handshake -- sending the MD5 challenge response -- happens here,
        before any frames are processed.

        """
        inq = Queue()
        outq = Queue()

        if req.rfc_handshake:
            handle_frames = self.handle_rfc_6455_frames
        else:
            # Finish the non-RFC handshake
            key1 = req.headers.get('Sec-WebSocket-Key1')
            key2 = req.headers.get('Sec-WebSocket-Key2')

            # The final key can be in two places. The first is in the
            # `Request.data` attribute if diesel is *not* being proxied
            # to by a smart proxy that parsed HTTP requests. If it is being
            # proxied to, that data will not have been sent until after our
            # initial 101 Switching Protocols response, so we will need to
            # receive it here.

            if req.data:
                key3 = req.data
            else:
                evt, key3 = first(receive=8, sleep=5)
                assert evt == "receive", "timed out while finishing handshake"

            def key_number(key, header):
                # A key encodes a number: all digits in the key read as one
                # integer, divided by the count of spaces in the key (which
                # must divide it exactly). Missing or space-less keys are
                # rejected up front so they fail the handshake assertion
                # instead of raising TypeError / ZeroDivisionError.
                assert key, "missing %s header" % header
                spaces = key.count(' ')
                assert spaces, "no spaces in %s header" % header
                number = int(''.join(c for c in key if c in '0123456789'))
                assert number % spaces == 0
                # Floor division matches Python 2 int `/` and keeps an int
                # (as the '!I' pack format requires) under true division.
                return number // spaces

            final = pack('!II8s',
                         key_number(key1, 'Sec-WebSocket-Key1'),
                         key_number(key2, 'Sec-WebSocket-Key2'),
                         key3)
            handshake_finish = hashlib.md5(final).digest()
            send(handshake_finish)

            handle_frames = self.handle_non_rfc_frames

        def wrap(req, inq, outq):
            # Queue a disconnect for the frame loop once the handler returns.
            self.web_socket_handler(req, inq, outq)
            outq.put(WebSocketDisconnect())

        handler_loop = fork(wrap, req, inq, outq)

        try:
            handle_frames(inq, outq)
        except ConnectionClosed:
            # Tell a handler that is still running that the peer is gone.
            if handler_loop.running:
                inq.put(WebSocketDisconnect())
            raise
예제 #9
0
파일: websockets.py 프로젝트: 1angxi/diesel
    def websocket_protocol(self, req):
        """Runs the WebSocket protocol after the handshake is complete.

        Creates two `Queue` instances for incoming and outgoing messages and
        passes them to the `web_socket_handler` that was supplied to the
        `WebSocketServer` constructor.

        For non-RFC (draft-76 style) clients, the final step of the
        handshake -- sending the MD5 challenge response -- happens here,
        before any frames are processed.

        """
        inq = Queue()
        outq = Queue()

        if req.rfc_handshake:
            handle_frames = self.handle_rfc_6455_frames
        else:
            # Finish the non-RFC handshake
            key1 = req.headers.get('Sec-WebSocket-Key1')
            key2 = req.headers.get('Sec-WebSocket-Key2')

            # The final key can be in two places. The first is in the
            # `Request.data` attribute if diesel is *not* being proxied
            # to by a smart proxy that parsed HTTP requests. If it is being
            # proxied to, that data will not have been sent until after our
            # initial 101 Switching Protocols response, so we will need to
            # receive it here.

            if req.data:
                key3 = req.data
            else:
                evt, key3 = first(receive=8, sleep=5)
                assert evt == "receive", "timed out while finishing handshake"

            def key_number(key, header):
                # A key encodes a number: all digits in the key read as one
                # integer, divided by the count of spaces in the key (which
                # must divide it exactly). Missing or space-less keys are
                # rejected up front so they fail the handshake assertion
                # instead of raising TypeError / ZeroDivisionError.
                assert key, "missing %s header" % header
                spaces = key.count(' ')
                assert spaces, "no spaces in %s header" % header
                number = int(''.join(c for c in key if c in '0123456789'))
                assert number % spaces == 0
                # Floor division matches Python 2 int `/` and keeps an int
                # (as the '!I' pack format requires) under true division.
                return number // spaces

            final = pack('!II8s',
                         key_number(key1, 'Sec-WebSocket-Key1'),
                         key_number(key2, 'Sec-WebSocket-Key2'),
                         key3)
            handshake_finish = hashlib.md5(final).digest()
            send(handshake_finish)

            handle_frames = self.handle_non_rfc_frames

        def wrap(req, inq, outq):
            # Queue a disconnect for the frame loop once the handler returns.
            self.web_socket_handler(req, inq, outq)
            outq.put(WebSocketDisconnect())

        handler_loop = fork(wrap, req, inq, outq)

        try:
            handle_frames(inq, outq)
        except ConnectionClosed:
            # Tell a handler that is still running that the peer is gone.
            if handler_loop.running:
                inq.put(WebSocketDisconnect())
            raise
예제 #10
0
def test_pending_events_dont_break_ordering_when_handling_early_values():

    # Scenario: diesel.first() is handed two event sources at once.
    #   * A fake connection with nothing buffered. It must register a
    #     callback and wait; per its `delay` argument the fake simulates the
    #     socket becoming readable later on.
    #   * A Queue that already holds an item, so it is ready immediately
    #     and produces an EarlyValue.
    # The queue must win, and the connection's pending callback must be
    # dropped. If it fired anyway it would switch an unexpected value into
    # this greenlet while it yields to others, scrambling internal ordering.

    conn = FakeConnection(1, delay=[None, 0.1])

    ready_queue = Queue()
    ready_queue.put(1)

    # Make the fake connection the current one for this loop so network
    # calls such as until_eol() read from it.
    loop = core.current_loop
    loop.connection_stack.append(conn)

    try:
        source, _ = diesel.first(until_eol=True, waits=[ready_queue])
        assert source == ready_queue, source

        # A fresh read must get the connection's data, untouched by the
        # callback that first() registered and should have discarded.
        line = diesel.until_eol()
        assert line == 'expected value 1\r\n', 'actual value == %r !!!' % (line,)
    finally:
        loop.connection_stack = []
예제 #11
0
def test_pending_events_dont_break_ordering_when_handling_early_values():

    # Scenario: diesel.first() is handed two event sources at once.
    #   * A fake connection with nothing buffered. It must register a
    #     callback and wait; per its `delay` argument the fake simulates the
    #     socket becoming readable later on.
    #   * A Queue that already holds an item, so it is ready immediately
    #     and produces an EarlyValue.
    # The queue must win, and the connection's pending callback must be
    # dropped. If it fired anyway it would switch an unexpected value into
    # this greenlet while it yields to others, scrambling internal ordering.

    conn = FakeConnection(1, delay=[None, 0.1])

    ready_queue = Queue()
    ready_queue.put(1)

    # Make the fake connection the current one for this loop so network
    # calls such as until_eol() read from it.
    loop = core.current_loop
    loop.connection_stack.append(conn)

    try:
        source, _ = diesel.first(until_eol=True, waits=[ready_queue])
        assert source == ready_queue, source

        # A fresh read must get the connection's data, untouched by the
        # callback that first() registered and should have discarded.
        line = diesel.until_eol()
        assert line == 'expected value 1\r\n', 'actual value == %r !!!' % (line, )
    finally:
        loop.connection_stack = []
예제 #12
0
class ProcessPool(object):
    """A bounded pool of subprocesses.

    An instance is callable, just like a Process, and will return the result
    of executing the function in a subprocess. If all subprocesses are busy,
    the caller will wait in a queue.

    """
    def __init__(self, concurrency, handler):
        """Creates a new ProcessPool with subprocesses that run the handler.

        Args:
            concurrency (int): The number of subprocesses to spawn.
            handler (callable): A callable that the subprocesses will execute.

        """
        self.concurrency = concurrency
        self.handler = handler
        self.available_procs = Queue()
        self.all_procs = []

    def __call__(self, *args, **params):
        """Gets a process from the pool, executes it, and returns the results.

        This call will block until there is a process available to handle it.
        The process is returned to the pool afterwards, even if the call
        raised.

        Raises:
            NoSubProcesses: if `pool()` has not been run yet.

        """
        if not self.all_procs:
            raise NoSubProcesses("Did you forget to start the pool?")
        # Check the process out *before* the try block: if get() itself
        # fails there is nothing to hand back, and a finally clause that
        # referenced an unbound `p` would mask the real error.
        p = self.available_procs.get()
        try:
            return p(*args, **params)
        finally:
            self.available_procs.put(p)

    def pool(self):
        """A callable that starts the processes in the pool.

        This is useful as the callable to pass to a diesel.Loop when adding a
        ProcessPool to your application.

        """
        for i in xrange(self.concurrency):
            proc = spawn(self.handler)
            self.available_procs.put(proc)
            self.all_procs.append(proc)
예제 #13
0
파일: process.py 프로젝트: 1angxi/diesel
class ProcessPool(object):
    """A bounded pool of subprocesses.

    An instance is callable, just like a Process, and will return the result
    of executing the function in a subprocess. If all subprocesses are busy,
    the caller will wait in a queue.

    """
    def __init__(self, concurrency, handler):
        """Creates a new ProcessPool with subprocesses that run the handler.

        Args:
            concurrency (int): The number of subprocesses to spawn.
            handler (callable): A callable that the subprocesses will execute.

        """
        self.concurrency = concurrency
        self.handler = handler
        self.available_procs = Queue()
        self.all_procs = []

    def __call__(self, *args, **params):
        """Gets a process from the pool, executes it, and returns the results.

        This call will block until there is a process available to handle it.
        The process is returned to the pool afterwards, even if the call
        raised.

        Raises:
            NoSubProcesses: if `pool()` has not been run yet.

        """
        if not self.all_procs:
            raise NoSubProcesses("Did you forget to start the pool?")
        # Check the process out *before* the try block: if get() itself
        # fails there is nothing to hand back, and a finally clause that
        # referenced an unbound `p` would mask the real error.
        p = self.available_procs.get()
        try:
            return p(*args, **params)
        finally:
            self.available_procs.put(p)

    def pool(self):
        """A callable that starts the processes in the pool.

        This is useful as the callable to pass to a diesel.Loop when adding a
        ProcessPool to your application.

        """
        for i in xrange(self.concurrency):
            proc = spawn(self.handler)
            self.available_procs.put(proc)
            self.all_procs.append(proc)
예제 #14
0
파일: riak.py 프로젝트: 1angxi/diesel
    def get_many(self, keys, concurrency_limit=100, no_failures=False):
        """Fetch many keys in parallel via pooled sub-requests.

        Every key is queued, then ``min(len(keys), concurrency_limit)``
        forked ``self._subrequest`` workers drain the queue. One
        ``(key, success, value)`` result is collected from ``outq`` per key.

        Args:
            keys: keys to fetch. Must support ``len()`` and be iterable more
                than once (e.g. a list or tuple).
            concurrency_limit: maximum number of concurrent sub-requests.
            no_failures: if true, raise BucketSubrequestException when at
                least one sub-request failed.

        Returns:
            ``(okay, err)``: lists of ``(key, value)`` pairs for successful
            and failed sub-requests, in the order results arrived.

        Raises:
            BucketSubrequestException: if ``no_failures`` is set and any
                sub-request failed; the exception carries the ``err`` list.
        """
        assert self.used_client_context, "Cannot fetch in parallel without a pooled make_client_context!"
        inq = Queue()
        outq = Queue()
        for k in keys:
            inq.put(k)

        for _ in xrange(min(len(keys), concurrency_limit)):
            diesel.fork(self._subrequest, inq, outq)

        okay, err = [], []
        for _ in keys:
            (key, success, val) = outq.get()
            if success:
                okay.append((key, val))
            else:
                err.append((key, val))

        # Only an actual failure is an error; an empty `err` means every
        # sub-request succeeded.
        if no_failures and err:
            raise BucketSubrequestException("Error in parallel subrequests", err)
        return okay, err
예제 #15
0
    def get_many(self, keys, concurrency_limit=100, no_failures=False):
        """Fetch many keys in parallel via pooled sub-requests.

        Every key is queued, then ``min(len(keys), concurrency_limit)``
        forked ``self._subrequest`` workers drain the queue. One
        ``(key, success, value)`` result is collected from ``outq`` per key.

        Args:
            keys: keys to fetch. Must support ``len()`` and be iterable more
                than once (e.g. a list or tuple).
            concurrency_limit: maximum number of concurrent sub-requests.
            no_failures: if true, raise BucketSubrequestException when at
                least one sub-request failed.

        Returns:
            ``(okay, err)``: lists of ``(key, value)`` pairs for successful
            and failed sub-requests, in the order results arrived.

        Raises:
            BucketSubrequestException: if ``no_failures`` is set and any
                sub-request failed; the exception carries the ``err`` list.
        """
        assert self.used_client_context,\
        "Cannot fetch in parallel without a pooled make_client_context!"
        inq = Queue()
        outq = Queue()
        for k in keys:
            inq.put(k)

        for _ in xrange(min(len(keys), concurrency_limit)):
            diesel.fork(self._subrequest, inq, outq)

        okay, err = [], []
        for _ in keys:
            (key, success, val) = outq.get()
            if success:
                okay.append((key, val))
            else:
                err.append((key, val))

        # Only an actual failure is an error; an empty `err` means every
        # sub-request succeeded.
        if no_failures and err:
            raise BucketSubrequestException("Error in parallel subrequests",
                                            err)
        return okay, err
예제 #16
0
    def do_upgrade(self, req):
        if req.headers.get_one('Upgrade') != 'WebSocket':
            return self.web_handler(req)

        # do upgrade response
        org = req.headers.get_one('Origin')

        send(
'''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
WebSocket-Origin: %s\r
WebSocket-Location: %s\r
WebSocket-Protocol: diesel-generic\r
\r
''' % (org, self.ws_location))
        
        inq = Queue()
        outq = Queue()

        def wrap(inq, outq):
            self.web_socket_handler(inq, outq)
            outq.put(WebSocketDisconnect())

        fork(wrap, inq, outq)
                                    
        while True:
            try:
                typ, val = first(receive=1, waits=[outq.wait_id])
                if typ == 'receive':
                    assert val == '\x00'
                    val = until('\xff')[:-1]
                    if val == '':
                        inq.put(WebSocketDisconnect())
                    else:
                        data = dict((k, v[0]) if len(v) == 1 else (k, v) for k, v in cgi.parse_qs(val).iteritems())
                        inq.put(WebSocketData(data))
                else:
                    try:
                        v = outq.get(waiting=False)
                    except QueueEmpty:
                        pass
                    else:
                        if type(v) is WebSocketDisconnect:
                            send('\x00\xff')
                            break
                        else:
                            data = dumps(dict(v))
                            send('\x00%s\xff' % data)

            except ConnectionClosed:
                inq.put(WebSocketDisconnect())
                raise ConnectionClosed("remote disconnected")
예제 #17
0
파일: __init__.py 프로젝트: dowski/aspen
def handle(request):
    """Handle a request for a websocket.

    Only the ``xhr-polling`` transport is accepted; anything else raises a
    404 ``Response``. ``handler`` is forked with the request and an
    (incoming, outgoing) queue pair, and this function then relays frames
    between the socket and those queues until the handler finishes or the
    client disconnects.
    """
    if request.transport != 'xhr-polling':
        raise Response(404)

    # NOTE(review): `org` is unused; the lookup is kept in case
    # `headers.one` raises on a missing Origin header -- confirm and drop.
    org = request.headers.one('Origin')
    inq = Queue()
    outq = Queue()

    def wrap(request, inq, outq):
        # Queue a disconnect once the handler returns so the loop exits.
        handler(request, inq, outq)
        outq.put(WebSocketDisconnect())

    fork(wrap, request, inq, outq)

    while True:
        try:
            log.debug("trying websocket thing")
            typ, val = first(receive=1, waits=[outq.wait_id])
            log.debug(typ)
            log.debug(val)
            if typ == 'receive':
                # Incoming frame: a 0x00 byte, the payload, then 0xff.
                assert val == '\x00'
                val = until('\xff')[:-1]
                if val == '':
                    inq.put(WebSocketDisconnect())
                else:
                    # NOTE(review): the request, not the frame payload `val`,
                    # is queued for the handler -- confirm this is intended.
                    inq.put(request)
            else:
                try:
                    v = outq.get(waiting=False)
                except QueueEmpty:
                    pass
                else:
                    if type(v) is WebSocketDisconnect:
                        send('\x00\xff')
                        break
                    else:
                        # Serialize the item the handler queued; `response`
                        # is not defined in this function and `v` would
                        # otherwise be dropped. Assumes queued items have a
                        # `to_http` method (aspen responses) -- TODO confirm.
                        send('\x00%s\xff' % v.to_http(request.version))

        except ConnectionClosed:
            inq.put(WebSocketDisconnect())
            raise ConnectionClosed("remote disconnected")
예제 #18
0
파일: __init__.py 프로젝트: dowski/aspen
def handle(request):
    """Handle a request for a websocket.

    Only the ``xhr-polling`` transport is accepted; anything else raises a
    404 ``Response``. ``handler`` is forked with the request and an
    (incoming, outgoing) queue pair, and this function then relays frames
    between the socket and those queues until the handler finishes or the
    client disconnects.
    """
    if request.transport != 'xhr-polling':
        raise Response(404)

    # NOTE(review): `org` is unused; the lookup is kept in case
    # `headers.one` raises on a missing Origin header -- confirm and drop.
    org = request.headers.one('Origin')
    inq = Queue()
    outq = Queue()

    def wrap(request, inq, outq):
        # Queue a disconnect once the handler returns so the loop exits.
        handler(request, inq, outq)
        outq.put(WebSocketDisconnect())
    fork(wrap, request, inq, outq)

    while True:
        try:
            log.debug("trying websocket thing")
            typ, val = first(receive=1, waits=[outq.wait_id])
            log.debug(typ)
            log.debug(val)
            if typ == 'receive':
                # Incoming frame: a 0x00 byte, the payload, then 0xff.
                assert val == '\x00'
                val = until('\xff')[:-1]
                if val == '':
                    inq.put(WebSocketDisconnect())
                else:
                    # NOTE(review): the request, not the frame payload `val`,
                    # is queued for the handler -- confirm this is intended.
                    inq.put(request)
            else:
                try:
                    v = outq.get(waiting=False)
                except QueueEmpty:
                    pass
                else:
                    if type(v) is WebSocketDisconnect:
                        send('\x00\xff')
                        break
                    else:
                        # Serialize the item the handler queued; `response`
                        # is not defined in this function and `v` would
                        # otherwise be dropped. Assumes queued items have a
                        # `to_http` method (aspen responses) -- TODO confirm.
                        send('\x00%s\xff' % v.to_http(request.version))

        except ConnectionClosed:
            inq.put(WebSocketDisconnect())
            raise ConnectionClosed("remote disconnected")
예제 #19
0
파일: nitro.py 프로젝트: 1angxi/diesel
class DieselNitroService(object):
    """A Nitro service that can handle multiple clients.

    Clients must maintain a steady flow of messages in order to maintain
    state in the service. A heartbeat of some sort. Or the timeout can be
    set to a sufficiently large value understanding that it will cause more
    resource consumption.

    Each client gets its own forked loop; a client that sends nothing for
    ``timeout`` (passed as diesel's ``sleep`` value, presumably seconds) is
    cleaned up.

    """
    name = ''
    default_log_level = loglevels.DEBUG
    timeout = 10

    def __init__(self, uri, logger=None, log_level=None):
        self.uri = uri
        self.nitro_socket = None
        self.log = logger or None
        self.selected_log_level = log_level
        # identity -> RemoteClient for every client seen and not yet cleaned up.
        self.clients = {}
        self.outgoing = Queue()
        self.incoming = Queue()
        self.name = self.name or self.__class__.__name__
        self._incoming_loop = None

        # Allow for custom `should_run` properties in subclasses.
        try:
            self.should_run = True
        except AttributeError:
            # A custom `should_run` property exists.
            pass

        if self.log and self.selected_log_level is not None:
            self.selected_log_level = None
            warnings.warn(
                "ignored `log_level` argument since `logger` was provided.",
                RuntimeWarning,
                stacklevel=2,
            )

    def _create_server_socket(self):
        self.nitro_socket = DieselNitroSocket(bind=self.uri)

    def _setup_the_logging_system(self):
        # Build a named logger only when none was supplied to __init__.
        if not self.log:
            if self.selected_log_level is not None:
                log_level = self.selected_log_level
            else:
                log_level = self.default_log_level
            log_name = self.name or self.__class__.__name__
            self.log = log.name(log_name)
            self.log.min_level = log_level

    def _handle_client_requests_and_responses(self, remote_client):
        """Per-client loop: answer each incoming message until idle timeout."""
        assert self.nitro_socket
        queues = [remote_client.incoming]
        try:
            while True:
                (evt, value) = diesel.first(waits=queues, sleep=self.timeout)
                if evt is remote_client.incoming:
                    assert isinstance(value, Message)
                    # Remembered so async_send can reply on this frame later.
                    remote_client.async_frame = value.orig_frame
                    resp = self.handle_client_packet(value.data, remote_client.context)
                    if resp:
                        # A single string is one reply; anything else is
                        # iterated as a sequence of replies.
                        if isinstance(resp, basestring):
                            output = [resp]
                        else:
                            output = iter(resp)
                        for part in output:
                            msg = Message(
                                value.orig_frame,
                                remote_client.identity,
                                self.serialize_message(remote_client.identity, part),
                            )
                            self.outgoing.put(msg)
                elif evt == 'sleep':
                    break
        finally:
            self._cleanup_client(remote_client)

    def _cleanup_client(self, remote_client):
        del self.clients[remote_client.identity]
        self.cleanup_client(remote_client)
        # Wrap in a tuple so a tuple-valued identity formats correctly.
        self.log.debug("cleaned up client %r" % (remote_client.identity,))

    def _handle_all_inbound_and_outbound_traffic(self):
        """Main loop: send queued replies and dispatch inbound messages."""
        assert self.nitro_socket
        queues = [self.nitro_socket, self.outgoing]
        socket = self.nitro_socket
        make_frame = pynitro.NitroFrame
        while self.should_run:
            (queue, msg) = diesel.first(waits=queues)

            if queue is self.outgoing:
                socket.reply(msg.orig_frame, make_frame(msg.data))
            else:
                id, obj = self.parse_message(msg.data)
                msg.clear_data()
                msg = Message(msg, id, obj)
                # First message from an identity registers (and forks) it.
                if msg.identity not in self.clients:
                    self._register_client(msg)
                self.clients[msg.identity].incoming.put(msg)

    def _register_client(self, msg):
        remote = RemoteClient.from_message(msg)
        self.clients[msg.identity] = remote
        self.register_client(remote, msg)
        diesel.fork_child(self._handle_client_requests_and_responses, remote)

    # Public API
    # ==========

    def __call__(self):
        return self.run()

    def run(self):
        self._create_server_socket()
        self._setup_the_logging_system()
        self._handle_all_inbound_and_outbound_traffic()

    def handle_client_packet(self, packet, context):
        """Called with a bytestring packet and dictionary context.

        Return an iterable of bytestrings.

        """
        raise NotImplementedError()

    def cleanup_client(self, remote_client):
        """Called with a RemoteClient instance. Do any cleanup you need to."""
        pass

    def register_client(self, remote_client, msg):
        """Called with a RemoteClient instance. Do any registration here."""
        pass

    def parse_message(self, raw_data):
        """Subclasses can override to alter the handling of inbound data.

        Transform an incoming bytestring into a structure (aka, json.loads)
        """
        return None, raw_data

    def serialize_message(self, identity, raw_data):
        """Subclasses can override to alter the handling of outbound data.

        Turn some structure into a bytestring (aka, json.dumps)
        """
        return raw_data

    def async_send(self, identity, msg):
        """Queue ``msg`` for delivery to the client ``identity``.

        The message is serialized with ``serialize_message`` and sent as a
        reply to the frame most recently received from that client.

        Raises KeyError if client is no longer connected.
        """
        remote_client = self.clients[identity]
        # serialize_message takes (identity, raw_data), exactly as it is
        # called in _handle_client_requests_and_responses.
        out = self.serialize_message(identity, msg)
        self.outgoing.put(
            Message(
                remote_client.async_frame,
                identity,
                out))
예제 #20
0
class Convoy(object):
    def __init__(self):
        """Create an empty Convoy: no roles, routes, handlers or host queues."""
        # message name -> set of hosts currently routed for it
        self.routes = defaultdict(set)

        # message-class and handler bookkeeping
        self.classes = {}
        self.local_handlers = {}
        self.enabled_handlers = {}

        # role bookkeeping
        self.role_messages = defaultdict(list)
        self.roles, self.roles_wanted, self.roles_owned = set(), set(), set()
        self.role_clocks = {}
        self.role_by_name = {}

        # per-host send queues and optional local nameserver
        self.host_queues = {}
        self.run_nameserver = None

        # work queues and outstanding requests (queue creation order kept)
        self.incoming = Queue()
        self.pending = {}
        self.rpc_waits = {}
        self.table_changes = Queue()

    def run_with_nameserver(self, myns, nameservers, *objs):
        """Like ``run``, but also start a consensus server for *myns*.

        ``run`` starts ``run_consensus_server(myns, ...)`` whenever
        ``self.run_nameserver`` is set.  *myns* is presumably this node's
        nameserver address; confirm with callers.
        """
        self.run_nameserver = myns
        run_args = (nameservers,) + objs
        self.run(*run_args)

    def run(self, nameservers, *objs):
        """Start the Convoy and its supporting loops via ``quickstart``.

        Args:
            nameservers: iterable of ``"host:port"`` strings.  The text
                after the last ``:`` is parsed as an int port.
            *objs: objects whose class's type is ``ConvoyRegistrar`` are
                treated as role instances.  Their role is added to
                ``self.roles_wanted`` and their ``handle_<message>`` methods
                become local handlers.  Any other object is passed to
                ``quickstart`` unchanged.

        If ``self.run_nameserver`` is set (see ``run_with_nameserver``), a
        consensus server is started as well.  Presumably does not return
        until the ``quickstart`` event loop stops; confirm.
        """
        # rsplit keeps host parts that themselves contain ':'.
        nameservers = [(host, int(port))
                       for host, port in (ns.rsplit(':', 1)
                                          for ns in nameservers)]
        runem = []
        if self.run_nameserver:
            runem.append(
                Thunk(lambda: run_consensus_server(self.run_nameserver,
                                                   nameservers)))
        runem.append(self)
        final_o = []
        for o in objs:
            if type(o.__class__) is ConvoyRegistrar:
                role = o.__class__
                self.roles_wanted.add(role)
                for m in self.role_messages[role]:
                    assert m not in self.local_handlers, \
                        "cannot add two instances for same role/message"
                    self.local_handlers[m] = getattr(o, 'handle_' + m)
            else:
                final_o.append(o)

        self.ns = ConvoyNameService(nameservers)
        runem.append(self.ns)
        runem.append(self.deliver)

        runem.extend(final_o)
        runem.append(ConvoyService())
        quickstart(*runem)

    def __call__(self):
        """Resolve role ownership and message routes in an endless loop.

        ``run`` appends this instance to the loops handed to ``quickstart``.

        The first pass processes every role in ``self.roles``.  Later passes
        process only the role named by a ``ConvoyWaitDone`` result from
        ``self.ns.wait``, or nothing.  For each processed role:

        * roles in ``self.roles_wanted`` are claimed through
          ``self.ns.add``, and ``self.roles_owned`` is updated from the reply;
        * other roles are looked up through ``self.ns.lookup``.

        When an answer is available, its ``clock`` is recorded and its
        ``members`` become the route for every message the role handles.
        The loop has no exit.
        """
        assert me.id
        # NOTE: on the first pass this aliases self.roles rather than copying it.
        should_process = self.roles
        rlog = log.sublog("convoy-resolver", LOGLVL_DEBUG)
        while True:
            for r in should_process:
                if r in self.roles_wanted:
                    # run() put r here because this process supplied an
                    # instance for it.  r.limit presumably caps the member
                    # count -- confirm.
                    resp = self.ns.add(r.name(), me.id, r.limit)
                    ans = None
                    if type(resp) == ConsensusSet:
                        # A ConsensusSet reply is treated as success.
                        self.roles_owned.add(r)
                        ans = resp
                    else:
                        # Otherwise we do not own r; use resp.set, if any.
                        if r in self.roles_owned:
                            self.roles_owned.remove(r)
                        if resp.set:
                            ans = resp.set
                else:
                    # Route-only role: just ask who owns it.
                    ans = self.ns.lookup(r.name())

                if ans:
                    self.role_clocks[r.name()] = ans.clock
                    # Every message of the role is routed to the same members.
                    for m in self.role_messages[r]:
                        self.routes[m] = ans.members

            if should_process:
                self.log_resolution_table(rlog, should_process)
                # Signal consumers of table_changes that routes were updated.
                self.table_changes.put(None)
            # Presumably waits up to 5 seconds for a change newer than
            # role_clocks -- confirm ConvoyNameService.wait semantics.
            wait_result = self.ns.wait(5, self.role_clocks)
            if type(wait_result) == ConvoyWaitDone:
                should_process = set([self.role_by_name[wait_result.key]])
            else:
                should_process = set()
            # Presumably a liveness heartbeat to the nameservers -- verify.
            self.ns.alive()

    def log_resolution_table(self, rlog, processed):
        rlog.debug("======== diesel/convoy routing table updates ========")
        rlog.debug("  ")
        for p in processed:
            rlog.debug("   %s [%s]" %
                       (p.name(), ', '.join(self.role_messages[p])))
            if self.role_messages:
                hosts = self.routes[self.role_messages[p][0]]
                for h in hosts:
                    rlog.debug("     %s %s" % ('*' if h == me.id else '-', h))

    def register(self, mod):
        for name in dir(mod):
            v = getattr(mod, name)
            if type(v) is type and issubclass(v, ProtoBase):
                self.classes[v.__name__] = v

    def add_target_role(self, o):
        self.roles.add(o)
        self.role_by_name[o.name()] = o
        for k, v in o.__dict__.iteritems():
            if k.startswith("handle_") and callable(v):
                handler_for = k.split("_", 1)[-1]
                assert handler_for in self.classes, "protobuf class not recognized; register() the module"
                self.role_messages[o].append(handler_for)

    def host_specific_send(self, host, msg, typ, transport_cb):
        if host not in self.host_queues:
            q = Queue()
            fork(host_loop, host, q)
            self.host_queues[host] = q

        self.host_queues[host].put((msg, typ, transport_cb))

    def local_dispatch(self, env):
        if env.type not in self.classes:
            self.host_specific_send(
                env.node_id,
                MessageResponse(in_response_to=env.req_id,
                                result=MessageResponse.REFUSED,
                                error_message="cannot handle type"),
                MESSAGE_RES, None)
        elif me.id not in self.routes[env.type]:
            # use routes, balance, etc
            self.host_specific_send(
                env.node_id,
                MessageResponse(in_response_to=env.req_id,
                                delivered=MessageResponse.REFUSED,
                                error_message="do not own route"), MESSAGE_RES,
                None)
        else:
            inst = self.classes[env.type](env.body)
            r = self.local_handlers[env.type]
            sender = ConvoySender(env)
            back = MessageResponse(in_response_to=env.req_id,
                                   delivered=MessageResponse.ACCEPTED)

            self.host_specific_send(env.node_id, back, MESSAGE_RES, None)
            try:
                r(sender, inst)
            except Exception, e:
                s = str(e)
                back.result = MessageResponse.EXCEPTION
                back.error_message = s
                raise
            else:
예제 #21
0
    def do_upgrade(self, req):
        if req.headers.get_one('Upgrade') != 'websocket' and req.headers.get_one('Upgrade') != 'WebSocket':
            return self.web_handler(req)

        hybi = False

        # do upgrade response
        org = req.headers.get_one('Origin')
        if 'Sec-WebSocket-Key' in req.headers:
            assert req.headers.get_one('Sec-WebSocket-Version') == '8', \
                   "We currently only support version 8 and below"

            protocol = (req.headers.get_one('Sec-WebSocket-Protocol')
                        if 'Sec-WebSocket-Protocol' in req.headers else None)
            key = req.headers.get_one('Sec-WebSocket-Key')
            accept = b64encode(hashlib.sha1(key + self.GUID).digest())
            send(server_handshake_hybi % accept)
            if protocol:
                send("Sec-WebSocket-Protocol: %s\r" % protocol)
            send("\r\n\r\n")
            hybi = True

        elif 'Sec-WebSocket-Key1' in req.headers:
            protocol = (req.headers.get_one('Sec-WebSocket-Protocol')
                        if 'Sec-WebSocket-Protocol' in req.headers else None)
            key1 = req.headers.get_one('Sec-WebSocket-Key1')
            key2 = req.headers.get_one('Sec-WebSocket-Key2')
            key3 = receive(8)
            num1 = int(''.join(c for c in key1 if c in '0123456789'))
            num2 = int(''.join(c for c in key2 if c in '0123456789'))
            assert num1 % key1.count(' ') == 0
            assert num2 % key2.count(' ') == 0
            final = pack('!II8s', num1 / key1.count(' '), num2 / key2.count(' '), key3)
            secure_response = hashlib.md5(final).digest()
            send(
'''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
Sec-WebSocket-Origin: %s\r
Sec-WebSocket-Location: %s\r
'''% (org, self.ws_location))
            if protocol:
                send("Sec-WebSocket-Protocol: %s\r\n" % (protocol,))
            send("\r\n")
            send(secure_response)

        else:
            send(
'''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
WebSocket-Origin: %s\r
WebSocket-Location: %s\r
WebSocket-Protocol: diesel-generic\r
\r
''' % (org, self.ws_location))


        inq = Queue()
        outq = Queue()

        def wrap(inq, outq):
            self.web_socket_handler(inq, outq)
            outq.put(WebSocketDisconnect())

        fork(wrap, inq, outq)

        while True:
            try:
                if hybi:
                    typ, val = first(receive=2, waits=[outq.wait_id])
                    if typ == 'receive':
                        b1, b2 = unpack(">BB", val)

                        opcode = b1 & 0x0f
                        fin = (b1 & 0x80) >> 7
                        has_mask = (b2 & 0x80) >> 7

                        assert has_mask == 1, "Frames must be masked"

                        if opcode == 8:
                            inq.put(WebSocketDisconnect())
                        else:
                            assert opcode == 1, "Currently only opcode 1 is supported"
                            length = b2 & 0x7f
                            if length == 126:
                                length = unpack('>H', receive(2))
                            elif length == 127:
                                length = unpack('>L', receive(8))

                            mask = unpack('>BBBB', receive(4))

                            payload = array('>B', receive(length))
                            for i in xrange(len(payload)):
                                payload[i] ^= mask[i % 4]

                            data = dict((k, v[0]) if len(v) == 1 else (k, v) for k, v in cgi.parse_qs(payload.tostring()).iteritems())
                            inq.put(WebSocketData(data))
                    else:
                        try:
                            v = outq.get(waiting=False)
                        except QueueEmpty:
                            pass
                        else:
                            if type(v) is WebSocketDisconnect:
                                b1 = 0x80 | (8 & 0x0f) # FIN + opcode
                                send(pack('>BB', b1, 0))
                                break
                            else:
                                payload = dumps(v)

                                b1 = 0x80 | (1 & 0x0f) # FIN + opcode

                                payload_len = len(payload)
                                if payload_len <= 125:
                                    header = pack('>BB', b1, payload_len)
                                elif payload_len > 125 and payload_len < 65536:
                                    header = pack('>BBH', b1, 126, payload_len)
                                elif payload_len >= 65536:
                                    header = pack('>BBQ', b1, 127, payload_len)

                            send(header + payload)
                else:
                    typ, val = first(receive=1, waits=[outq.wait_id])
                    if typ == 'receive':
                        assert val == '\x00'
                        val = until('\xff')[:-1]
                        if val == '':
                            inq.put(WebSocketDisconnect())
                        else:
                            data = dict((k, v[0]) if len(v) == 1 else (k, v) for k, v in cgi.parse_qs(val).iteritems())
                            inq.put(WebSocketData(data))
                    else:
                        try:
                            v = outq.get(waiting=False)
                        except QueueEmpty:
                            pass
                        else:
                            if type(v) is WebSocketDisconnect:
                                send('\x00\xff')
                                break
                            else:
                                data = dumps(dict(v))
                                send('\x00%s\xff' % data)


            except ConnectionClosed:
                inq.put(WebSocketDisconnect())
                raise ConnectionClosed("remote disconnected")
예제 #22
0
    def do_upgrade(self, req):
        """Upgrade *req* to a hixie-style WebSocket and relay messages.

        Requests whose ``Upgrade`` header is not exactly ``WebSocket`` are
        passed to ``self.web_handler``, and its result is returned.  If
        ``Sec-WebSocket-Key1`` is present, the hixie-76 challenge/response
        handshake is performed; otherwise the older hixie-75 response is
        sent.

        ``self.web_socket_handler(inq, outq)`` then runs in a forked loop.
        Client frames are delimited by 0x00 and 0xFF bytes; they are parsed
        with ``cgi.parse_qs`` and put on ``inq`` as ``WebSocketData``.
        Values the handler puts on ``outq`` are sent as
        ``dumps(dict(value))``.  A ``WebSocketDisconnect`` on ``outq`` sends
        an empty frame and ends the loop.

        Raises:
            ConnectionClosed: if the peer disconnects.  A
                ``WebSocketDisconnect`` is put on ``inq`` first.
        """
        if req.headers.get_one('Upgrade') != 'WebSocket':
            return self.web_handler(req)

        # do upgrade response
        org = req.headers.get_one('Origin')
        if 'Sec-WebSocket-Key1' in req.headers:
            protocol = (req.headers.get_one('Sec-WebSocket-Protocol')
                        if 'Sec-WebSocket-Protocol' in req.headers else None)
            key1 = req.headers.get_one('Sec-WebSocket-Key1')
            key2 = req.headers.get_one('Sec-WebSocket-Key2')
            # The 8-byte challenge body follows the request headers.
            key3 = receive(8)
            # Each key's digits, divided by its space count, give one
            # 32-bit value of the challenge.
            num1 = int(''.join(c for c in key1 if c in '0123456789'))
            num2 = int(''.join(c for c in key2 if c in '0123456789'))
            assert num1 % key1.count(' ') == 0
            assert num2 % key2.count(' ') == 0
            final = pack('!II8s', num1 / key1.count(' '),
                         num2 / key2.count(' '), key3)
            # The MD5 digest is sent after the response headers.
            secure_response = hashlib.md5(final).digest()
            send('''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
Sec-WebSocket-Origin: %s\r
Sec-WebSocket-Location: %s\r
''' % (org, self.ws_location))
            if protocol:
                send("Sec-WebSocket-Protocol: %s\r\n" % (protocol, ))
            send("\r\n")
            send(secure_response)
        else:
            # Older hixie-75 handshake: no challenge.
            send('''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
WebSocket-Origin: %s\r
WebSocket-Location: %s\r
WebSocket-Protocol: diesel-generic\r
\r
''' % (org, self.ws_location))

        inq = Queue()
        outq = Queue()

        def wrap(inq, outq):
            self.web_socket_handler(inq, outq)
            # Tell the frame loop below that the handler has finished.
            outq.put(WebSocketDisconnect())

        fork(wrap, inq, outq)

        while True:
            try:
                # Wake on one byte from the client (a frame start) or on outq.
                typ, val = first(receive=1, waits=[outq.wait_id])
                if typ == 'receive':
                    assert val == '\x00'
                    val = until('\xff')[:-1]
                    # An empty frame is treated as the client disconnecting.
                    if val == '':
                        inq.put(WebSocketDisconnect())
                    else:
                        # parse_qs yields lists; single values are unwrapped.
                        data = dict((k, v[0]) if len(v) == 1 else (k, v)
                                    for k, v in cgi.parse_qs(val).iteritems())
                        inq.put(WebSocketData(data))
                else:
                    try:
                        v = outq.get(waiting=False)
                    except QueueEmpty:
                        # Presumably the wait can fire with nothing queued.
                        pass
                    else:
                        if type(v) is WebSocketDisconnect:
                            send('\x00\xff')
                            break
                        else:
                            data = dumps(dict(v))
                            send('\x00%s\xff' % data)

            except ConnectionClosed:
                inq.put(WebSocketDisconnect())
                raise ConnectionClosed("remote disconnected")
예제 #23
0
    def do_upgrade(self, req):
        """Answer a WebSocket upgrade request, then run the frame loop.

        Requests whose ``Upgrade`` header is not ``websocket`` (compared
        case-insensitively) are delegated to ``self.web_handler``.  Two
        handshakes are understood:

        * ``Sec-WebSocket-Key`` (versions 8/13): reply with
          ``Sec-WebSocket-Accept``; frames go to ``handle_rfc_6455_frames``.
        * ``Sec-WebSocket-Key1``/``Key2`` (hixie-76): the MD5 challenge
          answer becomes the 101 response body; frames go to
          ``handle_non_rfc_frames``.

        Any other request fails an assertion.  ``self.web_socket_handler(req,
        inq, outq)`` runs in a forked loop.  A ``WebSocketDisconnect`` is
        queued on ``outq`` when it returns.

        Raises:
            ConnectionClosed: re-raised from the frame loop after a
                ``WebSocketDisconnect`` has been put on ``inq``.
        """
        hdrs = req.headers
        if hdrs.get('Upgrade', '').lower() != 'websocket':
            return self.web_handler(req)

        response_headers = {}
        origin = hdrs.get('Origin')
        challenge_answer = None

        def _digits(text):
            # Integer formed by the decimal digits scattered through text.
            return int(''.join(ch for ch in text if ch in '0123456789'))

        if 'Sec-WebSocket-Key' in hdrs:
            version = hdrs.get('Sec-WebSocket-Version')
            assert version in ['8', '13'], \
                "We currently only support Websockets version 8 and 13 (ver=%s)" % \
                version

            protocol = hdrs.get('Sec-WebSocket-Protocol', None)
            digest = hashlib.sha1(hdrs.get('Sec-WebSocket-Key') + self.GUID).digest()
            response_headers = {
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Accept': b64encode(digest),
            }
        elif 'Sec-WebSocket-Key1' in hdrs:
            protocol = hdrs.get('Sec-WebSocket-Protocol', None)
            key1 = hdrs.get('Sec-WebSocket-Key1')
            key2 = hdrs.get('Sec-WebSocket-Key2')
            response_headers = {
                'Upgrade': 'WebSocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Origin': origin,
                'Sec-WebSocket-Location': req.url.replace('http', 'ws', 1),
            }
            key3 = req.data
            assert len(key3) == 8, len(key3)
            n1, n2 = _digits(key1), _digits(key2)
            spaces1, spaces2 = key1.count(' '), key2.count(' ')
            assert n1 % spaces1 == 0
            assert n2 % spaces2 == 0
            challenge = pack('!II8s', n1 / spaces1, n2 / spaces2, key3)
            challenge_answer = hashlib.md5(challenge).digest()
        else:
            assert 0, "Unsupported WebSocket handshake."

        if protocol:
            response_headers['Sec-WebSocket-Protocol'] = protocol

        self.send_response(Response(
            response=challenge_answer or '',
            status=101,
            headers=response_headers,
        ))

        inq, outq = Queue(), Queue()

        def wrap(request, incoming, outgoing):
            self.web_socket_handler(request, incoming, outgoing)
            outgoing.put(WebSocketDisconnect())

        fork(wrap, req._get_current_object(), inq, outq)

        # A challenge answer exists only for the hixie-76 handshake.
        handle_frames = (self.handle_non_rfc_frames if challenge_answer
                         else self.handle_rfc_6455_frames)
        try:
            handle_frames(inq, outq)
        except ConnectionClosed:
            inq.put(WebSocketDisconnect())
            raise
예제 #24
0
    def do_upgrade(self, req):
        """Perform an RFC 6455 WebSocket upgrade and relay JSON messages.

        Requests whose ``Upgrade`` header is not ``websocket`` (compared
        case-insensitively) go to ``self.web_handler``.  Only the
        ``Sec-WebSocket-Key`` handshake (versions 8 and 13) is accepted.

        After the 101 response, ``self.web_socket_handler(req, inq, outq)``
        is forked with the current request object.  Client text frames are
        decoded with ``loads`` and put on ``inq``; payloads that fail to
        decode are dropped.  Values put on ``outq`` are encoded with
        ``dumps`` and sent as text frames.  A ``WebSocketDisconnect`` on
        ``outq`` sends a close frame and ends the loop.

        Raises:
            ConnectionClosed: if the peer disconnects.  A
                ``WebSocketDisconnect`` is put on ``inq`` first.
        """
        if req.headers.get('Upgrade', '').lower() != 'websocket':
            return self.web_handler(req)

        headers = {}

        # do upgrade response
        if 'Sec-WebSocket-Key' in req.headers:
            assert req.headers.get('Sec-WebSocket-Version') in ['8', '13'], \
                   "We currently only support Websockets version 8 and 13 (ver=%s)" % \
                   req.headers.get('Sec-WebSocket-Version')

            protocol = req.headers.get('Sec-WebSocket-Protocol')
            key = req.headers.get('Sec-WebSocket-Key')
            accept = b64encode(hashlib.sha1(key + self.GUID).digest())
            headers = {
                'Upgrade': 'websocket',
                'Connection': 'Upgrade',
                'Sec-WebSocket-Accept': accept,
            }
            if protocol:
                headers["Sec-WebSocket-Protocol"] = protocol
        else:
            assert 0, "Only RFC 6455 WebSockets are supported"

        resp = Response(response='', status=101, headers=headers)

        self.send_response(resp)

        inq = Queue()
        outq = Queue()

        def wrap(req, inq, outq):
            self.web_socket_handler(req, inq, outq)
            # Tell the frame loop below that the handler has finished.
            outq.put(WebSocketDisconnect())

        fork(wrap, req._get_current_object(), inq, outq)

        while True:
            try:
                typ, val = first(receive=2, waits=[outq])
                if typ == 'receive':
                    b1, b2 = unpack(">BB", val)

                    opcode = b1 & 0x0f
                    has_mask = (b2 & 0x80) >> 7

                    assert has_mask == 1, "Frames must be masked"

                    # Consume the whole frame before acting on the opcode.
                    # A close frame carrying a status code then leaves no
                    # unread bytes in the stream.
                    length = b2 & 0x7f
                    if length == 126:
                        length = unpack('>H', receive(2))[0]
                    elif length == 127:
                        # 64-bit extended payload length.
                        length = unpack('>Q', receive(8))[0]

                    mask = unpack('>BBBB', receive(4))
                    # Avoid receive(0) for empty payloads.
                    payload = array('B', receive(length) if length else '')
                    for i in xrange(len(payload)):
                        payload[i] ^= mask[i % 4]

                    if opcode == 8:
                        inq.put(WebSocketDisconnect())
                    else:
                        assert opcode == 1, "Currently only opcode 1 is supported (opcode=%s)" % opcode
                        try:
                            data = loads(payload.tostring())
                        except JSONDecodeError:
                            # Deliberately drop payloads that are not JSON.
                            pass
                        else:
                            inq.put(data)
                elif typ == outq:
                    if type(val) is WebSocketDisconnect:
                        b1 = 0x80 | (8 & 0x0f)  # FIN + close opcode
                        send(pack('>BB', b1, 0))
                        break

                    payload = dumps(val)
                    b1 = 0x80 | (1 & 0x0f)  # FIN + text opcode
                    payload_len = len(payload)
                    if payload_len <= 125:
                        header = pack('>BB', b1, payload_len)
                    elif payload_len < 65536:
                        header = pack('>BBH', b1, 126, payload_len)
                    else:
                        header = pack('>BBQ', b1, 127, payload_len)
                    send(header + payload)

            except ConnectionClosed:
                inq.put(WebSocketDisconnect())
                raise ConnectionClosed("remote disconnected")
예제 #25
0
파일: websockets.py 프로젝트: wmoss/diesel
    def do_upgrade(self, req):
        if req.headers.get_one(
                'Upgrade') != 'websocket' and req.headers.get_one(
                    'Upgrade') != 'WebSocket':
            return self.web_handler(req)

        hybi = False

        # do upgrade response
        org = req.headers.get_one('Origin')
        if 'Sec-WebSocket-Key' in req.headers:
            assert req.headers.get_one('Sec-WebSocket-Version') in ['8', '13'], \
                   "We currently only support Websockets version 8 and 13 (ver=%s)" % \
                   req.headers.get_one('Sec-WebSocket-Version')

            protocol = (req.headers.get_one('Sec-WebSocket-Protocol')
                        if 'Sec-WebSocket-Protocol' in req.headers else None)
            key = req.headers.get_one('Sec-WebSocket-Key')
            accept = b64encode(hashlib.sha1(key + self.GUID).digest())
            send(server_handshake_hybi % accept)
            if protocol:
                send("Sec-WebSocket-Protocol: %s\r" % protocol)
            send("\r\n\r\n")
            hybi = True

        elif 'Sec-WebSocket-Key1' in req.headers:
            protocol = (req.headers.get_one('Sec-WebSocket-Protocol')
                        if 'Sec-WebSocket-Protocol' in req.headers else None)
            key1 = req.headers.get_one('Sec-WebSocket-Key1')
            key2 = req.headers.get_one('Sec-WebSocket-Key2')
            key3 = receive(8)
            num1 = int(''.join(c for c in key1 if c in '0123456789'))
            num2 = int(''.join(c for c in key2 if c in '0123456789'))
            assert num1 % key1.count(' ') == 0
            assert num2 % key2.count(' ') == 0
            final = pack('!II8s', num1 / key1.count(' '),
                         num2 / key2.count(' '), key3)
            secure_response = hashlib.md5(final).digest()
            send('''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
Sec-WebSocket-Origin: %s\r
Sec-WebSocket-Location: %s\r
''' % (org, self.ws_location))
            if protocol:
                send("Sec-WebSocket-Protocol: %s\r\n" % (protocol, ))
            send("\r\n")
            send(secure_response)

        else:
            send('''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
WebSocket-Origin: %s\r
WebSocket-Location: %s\r
WebSocket-Protocol: diesel-generic\r
\r
''' % (org, self.ws_location))

        inq = Queue()
        outq = Queue()

        def wrap(req, inq, outq):
            self.web_socket_handler(req, inq, outq)
            outq.put(WebSocketDisconnect())

        fork(wrap, req, inq, outq)

        while True:
            try:
                if hybi:
                    typ, val = first(receive=2, waits=[outq])
                    if typ == 'receive':
                        b1, b2 = unpack(">BB", val)

                        opcode = b1 & 0x0f
                        fin = (b1 & 0x80) >> 7
                        has_mask = (b2 & 0x80) >> 7

                        assert has_mask == 1, "Frames must be masked"

                        if opcode == 8:
                            inq.put(WebSocketDisconnect())
                        else:
                            assert opcode == 1, "Currently only opcode 1 is supported (opcode=%s)" % opcode
                            length = b2 & 0x7f
                            if length == 126:
                                length = unpack('>H', receive(2))
                            elif length == 127:
                                length = unpack('>L', receive(8))

                            mask = unpack('>BBBB', receive(4))
                            payload = array('B', receive(length))
                            for i in xrange(len(payload)):
                                payload[i] ^= mask[i % 4]

                            try:
                                data = loads(payload.tostring())
                                inq.put(data)
                            except JSONDecodeError:
                                pass

                    elif typ == outq:
                        if type(val) is WebSocketDisconnect:
                            b1 = 0x80 | (8 & 0x0f)  # FIN + opcode
                            send(pack('>BB', b1, 0))
                            break
                        else:
                            payload = dumps(val)

                            b1 = 0x80 | (1 & 0x0f)  # FIN + opcode

                            payload_len = len(payload)
                            if payload_len <= 125:
                                header = pack('>BB', b1, payload_len)
                            elif payload_len > 125 and payload_len < 65536:
                                header = pack('>BBH', b1, 126, payload_len)
                            elif payload_len >= 65536:
                                header = pack('>BBQ', b1, 127, payload_len)

                        send(header + payload)
                else:
                    typ, val = first(receive=1, waits=[outq])
                    if typ == 'receive':
                        assert val == '\x00'
                        val = until('\xff')[:-1]
                        if val == '':
                            inq.put(WebSocketDisconnect())
                        else:
                            try:
                                data = loads(val)
                                inq.put(data)
                            except JSONDecodeError:
                                pass
                    elif typ == outq:
                        if type(val) is WebSocketDisconnect:
                            send('\x00\xff')
                            break
                        else:
                            data = dumps(dict(val))
                            send('\x00%s\xff' % data)

            except ConnectionClosed:
                inq.put(WebSocketDisconnect())
                raise ConnectionClosed("remote disconnected")
예제 #26
0
파일: websockets.py 프로젝트: dowski/aspen
    def do_upgrade(self, req):
        """Upgrade *req* to a hixie-style WebSocket and relay messages.

        Requests whose ``Upgrade`` header is not exactly ``WebSocket`` are
        passed to ``self.web_handler``, and its result is returned.  If
        ``Sec-WebSocket-Key1`` is present, the hixie-76 challenge/response
        handshake is performed; otherwise the older hixie-75 response is
        sent.

        ``self.web_socket_handler(inq, outq)`` then runs in a forked loop.
        Client frames are delimited by 0x00 and 0xFF bytes; they are parsed
        with ``cgi.parse_qs`` and put on ``inq`` as ``WebSocketData``.
        Values the handler puts on ``outq`` are sent as
        ``dumps(dict(value))``.  A ``WebSocketDisconnect`` on ``outq`` sends
        an empty frame and ends the loop.

        Raises:
            ConnectionClosed: if the peer disconnects.  A
                ``WebSocketDisconnect`` is put on ``inq`` first.
        """
        if req.headers.get_one('Upgrade') != 'WebSocket':
            return self.web_handler(req)

        # do upgrade response
        org = req.headers.get_one('Origin')
        if 'Sec-WebSocket-Key1' in req.headers:
            protocol = (req.headers.get_one('Sec-WebSocket-Protocol')
                        if 'Sec-WebSocket-Protocol' in req.headers else None)
            key1 = req.headers.get_one('Sec-WebSocket-Key1')
            key2 = req.headers.get_one('Sec-WebSocket-Key2')
            # The 8-byte challenge body follows the request headers.
            key3 = receive(8)
            # Each key's digits, divided by its space count, give one
            # 32-bit value of the challenge.
            num1 = int(''.join(c for c in key1 if c in '0123456789'))
            num2 = int(''.join(c for c in key2 if c in '0123456789'))
            assert num1 % key1.count(' ') == 0
            assert num2 % key2.count(' ') == 0
            final = pack('!II8s', num1 / key1.count(' '), num2 / key2.count(' '), key3)
            # The MD5 digest is sent after the response headers.
            secure_response = hashlib.md5(final).digest()
            send(
'''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
Sec-WebSocket-Origin: %s\r
Sec-WebSocket-Location: %s\r
'''% (org, self.ws_location))
            if protocol:
                send("Sec-WebSocket-Protocol: %s\r\n" % (protocol,))
            send("\r\n")
            send(secure_response)
        else:
            # Older hixie-75 handshake: no challenge.
            send(
'''HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
WebSocket-Origin: %s\r
WebSocket-Location: %s\r
WebSocket-Protocol: diesel-generic\r
\r
''' % (org, self.ws_location))

        inq = Queue()
        outq = Queue()

        def wrap(inq, outq):
            self.web_socket_handler(inq, outq)
            # Tell the frame loop below that the handler has finished.
            outq.put(WebSocketDisconnect())

        fork(wrap, inq, outq)

        while True:
            try:
                # Wake on one byte from the client (a frame start) or on outq.
                typ, val = first(receive=1, waits=[outq.wait_id])
                if typ == 'receive':
                    assert val == '\x00'
                    val = until('\xff')[:-1]
                    # An empty frame is treated as the client disconnecting.
                    if val == '':
                        inq.put(WebSocketDisconnect())
                    else:
                        # parse_qs yields lists; single values are unwrapped.
                        data = dict((k, v[0]) if len(v) == 1 else (k, v) for k, v in cgi.parse_qs(val).iteritems())
                        inq.put(WebSocketData(data))
                else:
                    try:
                        v = outq.get(waiting=False)
                    except QueueEmpty:
                        # Presumably the wait can fire with nothing queued.
                        pass
                    else:
                        if type(v) is WebSocketDisconnect:
                            send('\x00\xff')
                            break
                        else:
                            data = dumps(dict(v))
                            send('\x00%s\xff' % data)

            except ConnectionClosed:
                inq.put(WebSocketDisconnect())
                raise ConnectionClosed("remote disconnected")
예제 #27
0
class Convoy(object):
    def __init__(self):
        """Create an empty convoy node: no roles, routes or handlers yet."""
        # Message routing: message name -> hosts currently handling it.
        self.routes = defaultdict(set)
        # Message name -> bound ``handle_<name>`` method (filled by run()).
        self.local_handlers = {}
        self.enabled_handlers = {}
        # Registered ProtoBase class name -> class (filled by register()).
        self.classes = {}
        # Host -> outbound Queue serviced by a forked host_loop.
        self.host_queues = {}
        self.run_nameserver = None

        # Role bookkeeping.
        self.role_messages = defaultdict(list)
        self.roles, self.roles_wanted, self.roles_owned = set(), set(), set()
        self.role_clocks = {}
        self.role_by_name = {}

        self.incoming = Queue()
        self.pending = {}
        self.rpc_waits = {}
        # Receives a None every time the resolver refreshes routes.
        self.table_changes = Queue()

    def run_with_nameserver(self, myns, nameservers, *objs):
        """Like run(), but also start a consensus nameserver at *myns*.

        *myns* is stored on ``self.run_nameserver``; run() then adds a Thunk
        that calls ``run_consensus_server`` with it before starting the node.
        """
        self.run_nameserver = myns
        self.run(nameservers, *objs)

    def run(self, nameservers, *objs):
        """Start this convoy node and hand control to ``quickstart``.

        :param nameservers: iterable of ``"host:port"`` strings.
        :param objs: objects whose class has metaclass ``ConvoyRegistrar``
            are treated as role handlers. Their role is marked as wanted,
            and their ``handle_<message>`` methods become the local handlers
            for that role's messages. Any other object is passed through to
            ``quickstart`` unchanged.

        Asserts that no two handler objects claim the same message.
        """
        nameservers = [(h, int(p))
            for h, p in (i.split(':')
            for i in nameservers)]
        runem = []
        if self.run_nameserver:
            runem.append(
                Thunk(lambda: run_consensus_server(self.run_nameserver, nameservers)))
        runem.append(self)

        final_o = []
        for o in objs:
            r = o.__class__
            if type(r) is not ConvoyRegistrar:
                final_o.append(o)
                continue
            self.roles_wanted.add(r)
            for m in self.role_messages[r]:
                assert m not in self.local_handlers, \
                    "cannot add two instances for same role/message"
                self.local_handlers[m] = getattr(o, 'handle_' + m)

        self.ns = ConvoyNameService(nameservers)
        runem.append(self.ns)
        runem.append(self.deliver)

        runem.extend(final_o)
        runem.append(ConvoyService())
        quickstart(*runem)

    def __call__(self):
        """Resolver loop that keeps ``self.routes`` in sync with the nameserver.

        The first pass resolves every role in ``self.roles``. After that,
        only the role named by a ``ConvoyWaitDone`` result from
        ``self.ns.wait`` is re-resolved.

        For each resolved role, the answer's clock goes into ``role_clocks``.
        The answer's members become the route for every message of that
        role. Each pass that processed at least one role is logged and
        signalled on ``self.table_changes``.
        """
        assert me.id
        rlog = log.sublog("convoy-resolver", LOGLVL_DEBUG)
        todo = self.roles
        while True:
            for role in todo:
                answer = self._resolve_role(role)
                if not answer:
                    continue
                self.role_clocks[role.name()] = answer.clock
                for message_name in self.role_messages[role]:
                    self.routes[message_name] = answer.members

            if todo:
                self.log_resolution_table(rlog, todo)
                self.table_changes.put(None)

            outcome = self.ns.wait(5, self.role_clocks)
            if type(outcome) == ConvoyWaitDone:
                todo = set([self.role_by_name[outcome.key]])
            else:
                todo = set()
            self.ns.alive()

    def _resolve_role(self, role):
        """Claim *role* if wanted, else look it up; return its set or a falsy value.

        Wanted roles go through ``self.ns.add``. Ownership is recorded in
        ``roles_owned`` when the reply is a ``ConsensusSet`` and dropped
        otherwise.
        """
        if role not in self.roles_wanted:
            return self.ns.lookup(role.name())
        resp = self.ns.add(role.name(), me.id, role.limit)
        if type(resp) == ConsensusSet:
            self.roles_owned.add(role)
            return resp
        self.roles_owned.discard(role)
        return resp.set or None

    def log_resolution_table(self, rlog, processed):
        """Log, at debug level, the current routes for each role in *processed*.

        For every role this writes its name and handled message names. It
        then writes one line per host routed for the role, marking this
        node (``me.id``) with ``*`` and other hosts with ``-``.
        """
        rlog.debug("======== diesel/convoy routing table updates ========")
        rlog.debug("  ")
        for p in processed:
            messages = self.role_messages[p]
            rlog.debug("   %s [%s]" %
                    (p.name(),
                    ', '.join(messages)))
            # Check this role's own message list. The whole role_messages
            # mapping is truthy as soon as any role is known, which let
            # messages[0] raise IndexError for a role without handlers.
            if not messages:
                continue
            # __call__ assigns the same members to every message of a role,
            # so the first message's route is representative.
            for h in self.routes[messages[0]]:
                rlog.debug("     %s %s" % (
                    '*' if h == me.id else '-',
                    h))

    def register(self, mod):
        """Record every ProtoBase subclass found in module *mod*, keyed by class name."""
        for attr in dir(mod):
            candidate = getattr(mod, attr)
            if type(candidate) is not type:
                continue
            if issubclass(candidate, ProtoBase):
                self.classes[candidate.__name__] = candidate

    def add_target_role(self, o):
        """Register role *o* and the message types it declares handlers for.

        Every callable ``handle_<Message>`` entry in ``o.__dict__`` appends
        ``<Message>`` to ``self.role_messages[o]``. That message class must
        already be known via register().
        """
        self.roles.add(o)
        self.role_by_name[o.name()] = o
        for attr, value in o.__dict__.iteritems():
            if not attr.startswith("handle_") or not callable(value):
                continue
            message_name = attr.split("_", 1)[-1]
            assert message_name in self.classes, "protobuf class not recognized; register() the module"
            self.role_messages[o].append(message_name)

    def host_specific_send(self, host, msg, typ, transport_cb):
        """Queue ``(msg, typ, transport_cb)`` for delivery to *host*.

        The first send to a host creates its Queue and forks ``host_loop``
        to service it. Later sends reuse that queue.
        """
        outbox = self.host_queues.get(host)
        if outbox is None:
            outbox = Queue()
            fork(host_loop, host, outbox)
            self.host_queues[host] = outbox
        outbox.put((msg, typ, transport_cb))

    def local_dispatch(self, env):
        if env.type not in self.classes:
            self.host_specific_send(env.node_id,
            MessageResponse(in_response_to=env.req_id,
                result=MessageResponse.REFUSED,
                error_message="cannot handle type"),
            MESSAGE_RES, None)
        elif me.id not in self.routes[env.type]:
            # use routes, balance, etc
            self.host_specific_send(env.node_id,
            MessageResponse(in_response_to=env.req_id,
                delivered=MessageResponse.REFUSED,
                error_message="do not own route"),
            MESSAGE_RES, None)
        else:
            inst = self.classes[env.type](env.body)
            r = self.local_handlers[env.type]
            sender = ConvoySender(env)
            back = MessageResponse(in_response_to=env.req_id,
                    delivered=MessageResponse.ACCEPTED)

            self.host_specific_send(env.node_id, back,
                    MESSAGE_RES, None)
            try:
                r(sender, inst)
            except Exception, e:
                s = str(e)
                back.result = MessageResponse.EXCEPTION
                back.error_message = s
                raise
            else:
예제 #28
0
파일: websockets.py 프로젝트: HVF/diesel
    def do_upgrade(self, req):
        """Upgrade *req* to an RFC 6455 WebSocket and pump frames until closed.

        Requests without ``Upgrade: websocket`` are delegated to
        ``self.web_handler``. Otherwise this sends a 101 response and forks
        ``self.web_socket_handler(req, inq, outq)``. It then loops:

        * masked client text frames are JSON-decoded and put on ``inq``;
          payloads that fail to decode are dropped;
        * a client close frame (opcode 8) puts a ``WebSocketDisconnect`` on
          ``inq``;
        * values taken from ``outq`` are JSON-encoded and sent as single
          unmasked text frames;
        * a ``WebSocketDisconnect`` taken from ``outq`` (also queued when the
          handler returns) sends a close frame and ends the loop.

        Raises ``ConnectionClosed`` after putting a ``WebSocketDisconnect``
        on ``inq`` when the peer drops the connection. Protocol violations
        (wrong version, unmasked frames, non-text opcodes) fail assertions.
        """
        if req.headers.get('Upgrade', '').lower() != 'websocket':
            return self.web_handler(req)

        headers = {}
        if 'Sec-WebSocket-Key' in req.headers:
            assert req.headers.get('Sec-WebSocket-Version') in ['8', '13'], \
                   "We currently only support Websockets version 8 and 13 (ver=%s)" % \
                   req.headers.get('Sec-WebSocket-Version')

            protocol = req.headers.get('Sec-WebSocket-Protocol')
            key = req.headers.get('Sec-WebSocket-Key')
            # RFC 6455 section 4.2.2: accept = base64(SHA-1(key + GUID)).
            accept = b64encode(hashlib.sha1(key + self.GUID).digest())
            headers = {
                'Upgrade' : 'websocket',
                'Connection' : 'Upgrade',
                'Sec-WebSocket-Accept' : accept,
                }
            if protocol:
                headers["Sec-WebSocket-Protocol"] = protocol
        else:
            assert 0, "Only RFC 6455 WebSockets are supported"

        resp = Response(
                response='',
                status=101,
                headers=headers
                )
        self.send_response(resp)

        inq = Queue()
        outq = Queue()

        def wrap(req, inq, outq):
            self.web_socket_handler(req, inq, outq)
            # Handler finished: make the pump loop send a close frame.
            outq.put(WebSocketDisconnect())

        fork(wrap, req, inq, outq)

        while True:
            try:
                typ, val = first(receive=2, waits=[outq])
                if typ == 'receive':
                    b1, b2 = unpack(">BB", val)

                    opcode = b1 & 0x0f
                    fin = (b1 & 0x80) >> 7
                    has_mask = (b2 & 0x80) >> 7

                    assert has_mask == 1, "Frames must be masked"

                    if opcode == 8:
                        inq.put(WebSocketDisconnect())
                    else:
                        assert opcode == 1, "Currently only opcode 1 is supported (opcode=%s)" % opcode
                        length = b2 & 0x7f
                        if length == 126:
                            # 16-bit extended payload length.
                            length = unpack('>H', receive(2))[0]
                        elif length == 127:
                            # 64-bit extended payload length. This must be
                            # '>Q': '>L' is only 4 bytes and raised
                            # struct.error on the 8 bytes read here.
                            length = unpack('>Q', receive(8))[0]

                        mask = unpack('>BBBB', receive(4))
                        payload = array('B', receive(length))
                        for i in xrange(len(payload)):
                            payload[i] ^= mask[i % 4]

                        try:
                            data = loads(payload.tostring())
                            inq.put(data)
                        except JSONDecodeError:
                            pass
                elif typ == outq:
                    if type(val) is WebSocketDisconnect:
                        b1 = 0x80 | 0x08  # FIN + close opcode
                        send(pack('>BB', b1, 0))
                        break

                    payload = dumps(val)
                    b1 = 0x80 | 0x01  # FIN + text opcode

                    # Server frames are unmasked; pick the 7/16/64-bit
                    # length encoding from RFC 6455 section 5.2.
                    payload_len = len(payload)
                    if payload_len <= 125:
                        header = pack('>BB', b1, payload_len)
                    elif payload_len < 65536:
                        header = pack('>BBH', b1, 126, payload_len)
                    else:
                        header = pack('>BBQ', b1, 127, payload_len)

                    send(header + payload)

            except ConnectionClosed:
                inq.put(WebSocketDisconnect())
                raise ConnectionClosed("remote disconnected")
예제 #29
0
class ConvoyNameService(object):
    """Funnel nameserver requests through one consensus-client connection.

    Each public method enqueues a request object together with a private
    reply Queue, then blocks until the reply arrives. ``__call__`` is the
    loop that serves these requests. It connects a ``ConvoyConsensusClient``
    to a randomly chosen server, executes each request, and puts the
    result on the matching reply queue.
    """
    def __init__(self, servers):
        self.servers = servers
        self.request_queue = Queue()
        self.pool_locks = {}

    def __call__(self):
        while True:
            host_port = random.choice(self.servers)
            with ConvoyConsensusClient(*host_port) as client:
                while True:
                    request, reply_queue = self.request_queue.get()
                    reply_queue.put(self._execute(client, request))

    @staticmethod
    def _execute(client, request):
        """Run one request object against *client* and return its result."""
        kind = type(request)
        if kind is ConvoyGetRequest:
            return client.get(request.key)
        if kind is ConvoySetRequest:
            return client.add_to_set(request.key, request.value, request.cap,
                                     request.timeout, request.lock)
        if kind is ConvoyWaitRequest:
            return client.wait(request.timeout, request.clocks)
        if kind is ConvoyAliveRequest:
            return client.keep_alive()
        assert 0

    def _roundtrip(self, request):
        """Enqueue *request* and block until the service loop answers it."""
        reply_queue = Queue()
        self.request_queue.put((request, reply_queue))
        return reply_queue.get()

    def lookup(self, key):
        return self._roundtrip(ConvoyGetRequest(key))

    def clear(self, key):
        return self._roundtrip(ConvoySetRequest(key, None, 0, 5, 0))

    def set(self, key, value):
        return self._roundtrip(ConvoySetRequest(key, value, 0, 5, 0))

    def add(self, key, value, cap, to=0):
        return self._roundtrip(ConvoySetRequest(key, value, cap, to, 1))

    def wait(self, timeout, clocks):
        return self._roundtrip(ConvoyWaitRequest(timeout, clocks))

    def alive(self):
        return self._roundtrip(ConvoyAliveRequest())
예제 #30
0
class DieselZMQService(object):
    """A ZeroMQ service that can handle multiple clients.

    Clients must maintain a steady flow of messages in order to maintain
    state in the service. A heartbeat of some sort. Or the timeout can be
    set to a sufficiently large value understanding that it will cause more
    resource consumption.

    Subclasses implement ``handle_client_packet`` and may override
    ``register_client``, ``cleanup_client`` and
    ``convert_raw_data_to_message``.

    Class attributes:
      name -- service/logger name; the constructor falls back to the
          class name when empty.
      default_log_level -- level used when neither ``logger`` nor
          ``log_level`` is given.
      timeout -- idle period passed to ``diesel.first(sleep=...)``. A client
          with no traffic for that long is cleaned up.
    """
    name = ''
    default_log_level = loglevels.DEBUG
    timeout = 10

    def __init__(self, uri, logger=None, log_level=None):
        self.uri = uri
        self.zmq_socket = None
        self.log = logger or None
        self.selected_log_level = log_level
        self.clients = {}
        self.outgoing = Queue()
        self.incoming = Queue()
        self.name = self.name or self.__class__.__name__
        self._incoming_loop = None

        # Allow for custom `should_run` properties in subclasses.
        try:
            self.should_run = True
        except AttributeError:
            # A custom `should_run` property exists.
            pass

        if self.log and self.selected_log_level is not None:
            self.selected_log_level = None
            warnings.warn(
                "ignored `log_level` argument since `logger` was provided.",
                RuntimeWarning,
                stacklevel=2,
            )

    def _create_zeromq_server_socket(self):
        # TODO support other ZeroMQ socket types
        low_level_sock = zctx.socket(zmq.ROUTER)
        self.zmq_socket = DieselZMQSocket(low_level_sock, bind=self.uri)

    def _setup_the_logging_system(self):
        # Build a named logger only when the caller did not supply one.
        if not self.log:
            if self.selected_log_level is not None:
                log_level = self.selected_log_level
            else:
                log_level = self.default_log_level
            log_name = self.name or self.__class__.__name__
            self.log = log.name(log_name)
            self.log.min_level = log_level

    def _handle_client_requests_and_responses(self, remote_client):
        """Per-client loop: handle inbound packets and forward outbound data.

        A string response is sent as one message. Any other truthy response
        is iterated, and each part is sent as its own message. The loop ends
        (and the client is cleaned up) after ``self.timeout`` of silence.
        """
        assert self.zmq_socket
        queues = [remote_client.incoming, remote_client.outgoing]
        try:
            while True:
                (evt, value) = diesel.first(waits=queues, sleep=self.timeout)
                if evt is remote_client.incoming:
                    assert isinstance(value, Message)
                    # Update return path with latest (in case of reconnect)
                    remote_client.zmq_return = value.zmq_return
                    resp = self.handle_client_packet(value.data,
                                                     remote_client.context)
                elif evt is remote_client.outgoing:
                    resp = value
                elif evt == 'sleep':
                    break
                if resp:
                    if isinstance(resp, basestring):
                        output = [resp]
                    else:
                        output = iter(resp)
                    for part in output:
                        msg = Message(
                            remote_client.identity,
                            part,
                        )
                        msg.zmq_return = remote_client.zmq_return
                        self.outgoing.put(msg)
        finally:
            self._cleanup_client(remote_client)

    def _cleanup_client(self, remote_client):
        del self.clients[remote_client.identity]
        self.cleanup_client(remote_client)
        # Wrap in a 1-tuple: a tuple identity (possible via a custom
        # convert_raw_data_to_message) would otherwise be spread over the
        # format string and raise TypeError from this finally-path.
        self.log.debug("cleaned up client %r" % (remote_client.identity,))

    def _receive_incoming_messages(self):
        """Read (routing frame, payload frame) pairs into ``self.incoming``."""
        assert self.zmq_socket
        socket = self.zmq_socket
        while True:
            # TODO support receiving data from other socket types
            # A ROUTER socket delivers the sender's routing id as its own
            # frame, followed by the payload frame.
            zmq_return_routing_data = socket.recv(copy=False)
            assert zmq_return_routing_data.more
            zmq_return_routing = zmq_return_routing_data.bytes
            packet_raw = socket.recv()
            msg = self.convert_raw_data_to_message(zmq_return_routing,
                                                   packet_raw)
            msg.zmq_return = zmq_return_routing
            self.incoming.put(msg)

    def _handle_all_inbound_and_outbound_traffic(self):
        """Dispatch inbound messages to per-client loops; write outbound ones."""
        assert self.zmq_socket
        self._incoming_loop = diesel.fork_child(
            self._receive_incoming_messages)
        self._incoming_loop.keep_alive = True
        queues = [self.incoming, self.outgoing]
        while self.should_run:
            (queue, msg) = diesel.first(waits=queues)

            if queue is self.incoming:
                if msg.remote_identity not in self.clients:
                    self._register_client(msg)
                self.clients[msg.remote_identity].incoming.put(msg)

            elif queue is self.outgoing:
                self.zmq_socket.send(msg.zmq_return, zmq.SNDMORE)
                self.zmq_socket.send(msg.data)

    def _register_client(self, msg):
        remote = RemoteClient.from_message(msg)
        self.clients[msg.remote_identity] = remote
        self.register_client(remote, msg)
        diesel.fork_child(self._handle_client_requests_and_responses, remote)

    # Public API
    # ==========

    def __call__(self):
        return self.run()

    def run(self):
        self._create_zeromq_server_socket()
        self._setup_the_logging_system()
        self._handle_all_inbound_and_outbound_traffic()

    def handle_client_packet(self, packet, context):
        """Called with a bytestring packet and dictionary context.

        Return an iterable of bytestrings.

        """
        raise NotImplementedError()

    def cleanup_client(self, remote_client):
        """Called with a RemoteClient instance. Do any cleanup you need to."""
        pass

    def register_client(self, remote_client, msg):
        """Called with a RemoteClient instance. Do any registration here."""
        pass

    def convert_raw_data_to_message(self, zmq_return, raw_data):
        """Subclasses can override to alter the handling of inbound data.

        Importantly, they can route the message based on the raw_data and
        even convert the raw_data to something more application specific
        and pass it to the Message constructor.

        This default implementation uses the zmq_return identifier for the
        remote socket as the identifier and passes the raw_data to the
        Message constructor.

        """
        return Message(zmq_return, raw_data)
예제 #31
0
class DieselNitroService(object):
    """A Nitro service that can handle multiple clients.

    Clients must maintain a steady flow of messages in order to maintain
    state in the service. A heartbeat of some sort. Or the timeout can be
    set to a sufficiently large value understanding that it will cause more
    resource consumption.

    Subclasses implement ``handle_client_packet`` and may override
    ``register_client``, ``cleanup_client``, ``parse_message`` and
    ``serialize_message``.

    Class attributes:
      name -- service/logger name; the constructor falls back to the
          class name when empty.
      default_log_level -- level used when neither ``logger`` nor
          ``log_level`` is given.
      timeout -- idle period passed to ``diesel.first(sleep=...)``. A client
          with no traffic for that long is cleaned up.
    """
    name = ''
    default_log_level = loglevels.DEBUG
    timeout = 10

    def __init__(self, uri, logger=None, log_level=None):
        self.uri = uri
        self.nitro_socket = None
        self.log = logger or None
        self.selected_log_level = log_level
        self.clients = {}
        self.outgoing = Queue()
        self.incoming = Queue()
        self.name = self.name or self.__class__.__name__
        self._incoming_loop = None

        # Allow for custom `should_run` properties in subclasses.
        try:
            self.should_run = True
        except AttributeError:
            # A custom `should_run` property exists.
            pass

        if self.log and self.selected_log_level is not None:
            self.selected_log_level = None
            warnings.warn(
                "ignored `log_level` argument since `logger` was provided.",
                RuntimeWarning,
                stacklevel=2,
            )

    def _create_server_socket(self):
        self.nitro_socket = DieselNitroSocket(bind=self.uri)

    def _setup_the_logging_system(self):
        # Build a named logger only when the caller did not supply one.
        if not self.log:
            if self.selected_log_level is not None:
                log_level = self.selected_log_level
            else:
                log_level = self.default_log_level
            log_name = self.name or self.__class__.__name__
            self.log = log.name(log_name)
            self.log.min_level = log_level

    def _handle_client_requests_and_responses(self, remote_client):
        """Per-client loop: run ``handle_client_packet`` on each inbound message.

        Each message's frame is remembered as ``remote_client.async_frame``
        for later use by async_send(). A string response is sent as one
        reply. Any other truthy response is iterated, and each part is
        serialized and sent as its own reply. The loop ends (and the client
        is cleaned up) after ``self.timeout`` of silence.
        """
        assert self.nitro_socket
        queues = [remote_client.incoming]
        try:
            while True:
                (evt, value) = diesel.first(waits=queues, sleep=self.timeout)
                if evt is remote_client.incoming:
                    assert isinstance(value, Message)
                    remote_client.async_frame = value.orig_frame
                    resp = self.handle_client_packet(value.data,
                                                     remote_client.context)
                    if resp:
                        if isinstance(resp, basestring):
                            output = [resp]
                        else:
                            output = iter(resp)
                        for part in output:
                            msg = Message(
                                value.orig_frame,
                                remote_client.identity,
                                self.serialize_message(remote_client.identity,
                                                       part),
                            )
                            self.outgoing.put(msg)
                elif evt == 'sleep':
                    break
        finally:
            self._cleanup_client(remote_client)

    def _cleanup_client(self, remote_client):
        del self.clients[remote_client.identity]
        self.cleanup_client(remote_client)
        # Wrap in a 1-tuple: parse_message may return a tuple identity,
        # which would otherwise be spread over the format string and raise
        # TypeError from this finally-path.
        self.log.debug("cleaned up client %r" % (remote_client.identity,))

    def _handle_all_inbound_and_outbound_traffic(self):
        """Reply outbound messages on their frames; dispatch inbound by identity."""
        assert self.nitro_socket
        queues = [self.nitro_socket, self.outgoing]
        socket = self.nitro_socket
        make_frame = pynitro.NitroFrame
        while self.should_run:
            (queue, msg) = diesel.first(waits=queues)

            if queue is self.outgoing:
                socket.reply(msg.orig_frame, make_frame(msg.data))
            else:
                identity, obj = self.parse_message(msg.data)
                msg.clear_data()
                msg = Message(msg, identity, obj)
                if msg.identity not in self.clients:
                    self._register_client(msg)
                self.clients[msg.identity].incoming.put(msg)

    def _register_client(self, msg):
        remote = RemoteClient.from_message(msg)
        self.clients[msg.identity] = remote
        self.register_client(remote, msg)
        diesel.fork_child(self._handle_client_requests_and_responses, remote)

    # Public API
    # ==========

    def __call__(self):
        return self.run()

    def run(self):
        self._create_server_socket()
        self._setup_the_logging_system()
        self._handle_all_inbound_and_outbound_traffic()

    def handle_client_packet(self, packet, context):
        """Called with a bytestring packet and dictionary context.

        Return an iterable of bytestrings.

        """
        raise NotImplementedError()

    def cleanup_client(self, remote_client):
        """Called with a RemoteClient instance. Do any cleanup you need to."""
        pass

    def register_client(self, remote_client, msg):
        """Called with a RemoteClient instance. Do any registration here."""
        pass

    def parse_message(self, raw_data):
        """Subclasses can override to alter the handling of inbound data.

        Transform an incoming bytestring into a structure (aka, json.loads)
        """
        return None, raw_data

    def serialize_message(self, identity, raw_data):
        """Subclasses can override to alter the handling of outbound data.

        Turn some structure into a bytestring (aka, json.dumps)
        """
        return raw_data

    def async_send(self, identity, msg):
        """Queue *msg* for the client *identity* outside a request/response cycle.

        The message is serialized with serialize_message() and queued as a
        reply on the frame of that client's most recent inbound message.

        Raises KeyError if client is no longer connected.
        """
        remote_client = self.clients[identity]
        # serialize_message takes (identity, raw_data); calling it with only
        # the message raised TypeError on every async_send.
        out = self.serialize_message(identity, msg)
        self.outgoing.put(Message(remote_client.async_frame, identity, out))
예제 #32
0
파일: zeromq.py 프로젝트: arnaudsj/diesel
class DieselZMQService(object):
    """A ZeroMQ service that can handle multiple clients.

    Clients must maintain a steady flow of messages in order to maintain
    state in the service. A heartbeat of some sort. Or the timeout can be
    set to a sufficiently large value understanding that it will cause more
    resource consumption.

    Subclasses implement ``handle_client_packet`` and may override
    ``register_client``, ``cleanup_client`` and
    ``convert_raw_data_to_message``.

    Class attributes:
      name -- service/logger name; the constructor falls back to the
          class name when empty.
      default_log_level -- level used when neither ``logger`` nor
          ``log_level`` is given.
      timeout -- idle period passed to ``diesel.first(sleep=...)``. A client
          with no traffic for that long is cleaned up.
    """
    name = ''
    default_log_level = loglevels.DEBUG
    timeout = 10

    def __init__(self, uri, logger=None, log_level=None):
        self.uri = uri
        self.zmq_socket = None
        self.log = logger or None
        self.selected_log_level = log_level
        self.clients = {}
        self.outgoing = Queue()
        self.incoming = Queue()
        self.name = self.name or self.__class__.__name__
        if self.log and self.selected_log_level is not None:
            self.selected_log_level = None
            warnings.warn(
                "ignored `log_level` argument since `logger` was provided.",
                RuntimeWarning,
                stacklevel=2,
            )

    def _create_zeromq_server_socket(self):
        # TODO support other ZeroMQ socket types
        low_level_sock = zctx.socket(zmq.ROUTER)
        self.zmq_socket = DieselZMQSocket(low_level_sock, bind=self.uri)

    def _setup_the_logging_system(self):
        # Build a named logger only when the caller did not supply one.
        if not self.log:
            if self.selected_log_level is not None:
                log_level = self.selected_log_level
            else:
                log_level = self.default_log_level
            log_name = self.name or self.__class__.__name__
            self.log = log.name(log_name)
            self.log.min_level = log_level

    def _handle_client_requests_and_responses(self, remote_client):
        """Per-client loop: handle inbound packets and forward outbound data.

        A string response is sent as one message. Any other truthy response
        is iterated, and each part is sent as its own message. The loop ends
        (and the client is cleaned up) after ``self.timeout`` of silence.
        """
        assert self.zmq_socket
        queues = [remote_client.incoming, remote_client.outgoing]
        try:
            while True:
                (evt, value) = diesel.first(waits=queues, sleep=self.timeout)
                if evt is remote_client.incoming:
                    assert isinstance(value, Message)
                    resp = self.handle_client_packet(value.data, remote_client.context)
                elif evt is remote_client.outgoing:
                    resp = value
                elif evt == 'sleep':
                    break
                if resp:
                    if isinstance(resp, basestring):
                        output = [resp]
                    else:
                        output = iter(resp)
                    for part in output:
                        msg = Message(
                            remote_client.identity,
                            part,
                        )
                        msg.zmq_return = remote_client.zmq_return
                        self.outgoing.put(msg)
        finally:
            self._cleanup_client(remote_client)

    def _cleanup_client(self, remote_client):
        del self.clients[remote_client.identity]
        self.cleanup_client(remote_client)
        # Wrap in a 1-tuple: a tuple identity (possible via a custom
        # convert_raw_data_to_message) would otherwise be spread over the
        # format string and raise TypeError from this finally-path.
        self.log.debug("cleaned up client %r" % (remote_client.identity,))

    def _receive_incoming_messages(self):
        """Read (routing frame, payload frame) pairs into ``self.incoming``."""
        assert self.zmq_socket
        socket = self.zmq_socket
        while True:
            # TODO support receiving data from other socket types
            # A ROUTER socket delivers the sender's routing id as its own
            # frame, followed by the payload frame.
            zmq_return_routing_data = socket.recv(copy=False)
            assert zmq_return_routing_data.more
            zmq_return_routing = zmq_return_routing_data.bytes
            packet_raw = socket.recv()
            msg = self.convert_raw_data_to_message(zmq_return_routing, packet_raw)
            msg.zmq_return = zmq_return_routing
            self.incoming.put(msg)

    def _handle_all_inbound_and_outbound_traffic(self):
        """Dispatch inbound messages to per-client loops; write outbound ones."""
        assert self.zmq_socket
        diesel.fork_child(self._receive_incoming_messages)
        queues = [self.incoming, self.outgoing]
        while True:
            (queue, msg) = diesel.first(waits=queues)

            if queue is self.incoming:
                if msg.remote_identity not in self.clients:
                    self._register_client(msg)
                self.clients[msg.remote_identity].incoming.put(msg)

            elif queue is self.outgoing:
                self.zmq_socket.send(msg.zmq_return, zmq.SNDMORE)
                self.zmq_socket.send(msg.data)

    def _register_client(self, msg):
        remote = RemoteClient.from_message(msg)
        self.clients[msg.remote_identity] = remote
        self.register_client(remote, msg)
        diesel.fork_child(self._handle_client_requests_and_responses, remote)

    # Public API
    # ==========

    def run(self):
        self._create_zeromq_server_socket()
        self._setup_the_logging_system()
        self._handle_all_inbound_and_outbound_traffic()

    def handle_client_packet(self, packet, context):
        """Called with a bytestring packet and dictionary context.

        Return an iterable of bytestrings.

        """
        raise NotImplementedError()

    def cleanup_client(self, remote_client):
        """Called with a RemoteClient instance. Do any cleanup you need to."""
        pass

    def register_client(self, remote_client, msg):
        """Called with a RemoteClient instance. Do any registration here."""
        pass

    def convert_raw_data_to_message(self, zmq_return, raw_data):
        """Subclasses can override to alter the handling of inbound data.

        Importantly, they can route the message based on the raw_data and
        even convert the raw_data to something more application specific
        and pass it to the Message constructor.

        This default implementation uses the zmq_return identifier for the
        remote socket as the identifier and passes the raw_data to the
        Message constructor.

        """
        return Message(zmq_return, raw_data)