예제 #1
0
def parse(mock_write):
    """Decode all messages captured by a mocked ``write``.

    For every entry in ``mock_write.call_args_list`` the first element of
    the entry's first item (``entry[0][0]``) is fed into a fresh
    ``Unpacker``; whatever the unpacker then yields is returned as a list,
    in iteration order.
    """
    decoder = Unpacker()
    for recorded in mock_write.call_args_list:
        positional = recorded[0]
        decoder.feed(positional[0])
    return [message for message in decoder]
예제 #2
0
class BaseProtocol(object):
    """Base class that turns a byte stream into per-opcode callbacks.

    Incoming bytes are pushed into an ``Unpacker`` by ``data_received``;
    every ``(opcode, data)`` pair it yields is routed by
    ``message_received`` to the matching ``on_*`` hook. Subclasses override
    the hooks they care about; the ``on_*`` defaults raise
    ``NotImplementedError``.

    ``self.transport`` is used to close the connection but is not assigned
    in this class — presumably it is set by the code that wires the protocol
    to a connection (TODO confirm against caller).
    """

    def __init__(self):
        self.unpacker = Unpacker()

    def connection_made(self):
        """Hook with a no-op default."""
        pass

    def connection_lost(self, reason):
        """Hook with a no-op default."""
        pass

    def protocol_error(self, reason):
        """Hook invoked with a description just before the transport is closed."""
        pass

    def on_error(self, error):
        """Handle a parsed OP_ERROR message."""
        # Fixed: previously referenced self.on_info, which made the
        # NotImplementedError point at the wrong handler.
        raise NotImplementedError(self.on_error)

    def on_info(self, name, rand):
        """Handle a parsed OP_INFO message."""
        raise NotImplementedError(self.on_info)

    def on_auth(self, ident, hash):
        """Handle a parsed OP_AUTH message."""
        raise NotImplementedError(self.on_auth)

    def on_publish(self, ident, chan, data):
        """Handle a parsed OP_PUBLISH message."""
        raise NotImplementedError(self.on_publish)

    def on_subscribe(self, ident, channel):
        """Handle a parsed OP_SUBSCRIBE message."""
        raise NotImplementedError(self.on_subscribe)

    def on_unsubscribe(self, ident, channel):
        """Handle a parsed OP_UNSUBSCRIBE message."""
        raise NotImplementedError(self.on_unsubscribe)

    def message_received(self, opcode, data):
        """Decode ``data`` according to ``opcode`` and call the matching hook.

        Returns whatever the hook returns. For an opcode that matches none
        of the known constants, ``protocol_error`` is called and the
        transport is closed.
        """
        if opcode == OP_ERROR:
            return self.on_error(readerror(data))
        elif opcode == OP_INFO:
            return self.on_info(*readinfo(data))
        elif opcode == OP_AUTH:
            return self.on_auth(*readauth(data))
        elif opcode == OP_PUBLISH:
            return self.on_publish(*readpublish(data))
        elif opcode == OP_SUBSCRIBE:
            return self.on_subscribe(*readsubscribe(data))
        elif opcode == OP_UNSUBSCRIBE:
            return self.on_unsubscribe(*readunsubscribe(data))

        # Can't recover from an unknown opcode, so drop connection
        self.protocol_error('Unknown message opcode: {!r}'.format(opcode))
        self.transport.close()

    def data_received(self, data):
        """Feed raw bytes to the unpacker and dispatch each yielded message.

        A ``ProtocolException`` raised while iterating or dispatching is
        reported through ``protocol_error`` and the transport is closed.
        """
        self.unpacker.feed(data)
        try:
            # Distinct loop name so the ``data`` parameter is not shadowed.
            for opcode, payload in self.unpacker:
                self.message_received(opcode, payload)
        except ProtocolException as e:
            self.protocol_error(str(e))
            self.transport.close()
예제 #3
0
    def test_unpack_2(self):
        """Feed a message one byte at a time; only the last byte completes it."""
        payload = b'abcdefghijklmnopqrstuvwxyz'
        frame = msghdr(1, payload)
        unpacker = Unpacker()

        # Until the final byte arrives the unpacker must not yield anything,
        # because the object is still incomplete.
        for byte in frame[:-1]:
            unpacker.feed([byte])
            decoded = list(iter(unpacker))
            assert decoded == []

        # The last byte completes the frame, so exactly one message appears.
        last_byte = frame[-1]
        unpacker.feed([last_byte])
        decoded = list(iter(unpacker))
        assert decoded == [(1, b'abcdefghijklmnopqrstuvwxyz')]
예제 #4
0
class BaseProtocol(Protocol):
    '''
    Protocol base that unpacks incoming bytes into messages, routes each one
    to an ``on*`` callback, and offers helpers that write encoded messages to
    the transport.
    '''

    def __init__(self):
        self.unpacker = Unpacker()

    def protocolError(self, reason):
        '''
        Invoked once an unrecoverable protocol error is detected; the caller
        drops the connection afterwards. Logs the reason.
        '''
        log.err(reason)

    def onError(self, error):
        '''
        Callback for a parsed OP_ERROR (dispatched from messageReceived).
        '''
        raise NotImplementedError(self.onError)

    def onInfo(self, name, rand):
        '''
        Callback for a parsed OP_INFO (dispatched from messageReceived).
        '''
        raise NotImplementedError(self.onInfo)

    def onAuth(self, ident, secret):
        '''
        Callback for a parsed OP_AUTH (dispatched from messageReceived).
        '''
        raise NotImplementedError(self.onAuth)

    def onPublish(self, ident, chan, data):
        '''
        Callback for a parsed OP_PUBLISH (dispatched from messageReceived).
        '''
        raise NotImplementedError(self.onPublish)

    def onSubscribe(self, ident, chan):
        '''
        Callback for a parsed OP_SUBSCRIBE (dispatched from messageReceived).
        '''
        raise NotImplementedError(self.onSubscribe)

    def onUnsubscribe(self, ident, chan):
        '''
        Callback for a parsed OP_UNSUBSCRIBE (dispatched from
        messageReceived).
        '''
        raise NotImplementedError(self.onUnsubscribe)

    def messageReceived(self, opcode, data):
        '''
        Decode ``data`` for the given ``opcode`` and return the result of the
        matching callback. An opcode that matches none of the known constants
        is reported via protocolError and the connection is dropped.
        '''
        if opcode == OP_ERROR:
            return self.onError(readerror(data))
        if opcode == OP_INFO:
            return self.onInfo(*readinfo(data))
        if opcode == OP_AUTH:
            return self.onAuth(*readauth(data))
        if opcode == OP_PUBLISH:
            return self.onPublish(*readpublish(data))
        if opcode == OP_SUBSCRIBE:
            return self.onSubscribe(*readsubscribe(data))
        if opcode == OP_UNSUBSCRIBE:
            return self.onUnsubscribe(*readunsubscribe(data))

        # Nothing matched: there is no way to resynchronise, so hang up.
        description = 'Unknown message opcode: {!r}'.format(opcode)
        self.protocolError(description)
        self.transport.loseConnection()

    def dataReceived(self, data):
        '''
        Push received bytes into the unpacker and dispatch every message it
        yields.
        '''
        self.unpacker.feed(data)
        try:
            for opcode, payload in self.unpacker:
                self.messageReceived(opcode, payload)
        except ProtocolException as exc:
            # A decoding failure is fatal for the stream, so hang up.
            self.protocolError(str(exc))
            self.transport.loseConnection()

    def error(self, error):
        '''Write an encoded error message to the transport.'''
        transport = self.transport
        transport.write(msgerror(error))

    def info(self, name, rand):
        '''Write an encoded info message to the transport.'''
        transport = self.transport
        transport.write(msginfo(name, rand))

    def auth(self, rand, ident, secret):
        '''Write an encoded auth message to the transport.'''
        transport = self.transport
        transport.write(msgauth(rand, ident, secret))

    def publish(self, ident, channel, payload):
        '''Write an encoded publish message to the transport.'''
        transport = self.transport
        transport.write(msgpublish(ident, channel, payload))

    def subscribe(self, ident, channel):
        '''Write an encoded subscribe message to the transport.'''
        transport = self.transport
        transport.write(msgsubscribe(ident, channel))

    def unsubscribe(self, ident, channel):
        '''Write an encoded unsubscribe message to the transport.'''
        transport = self.transport
        transport.write(msgunsubscribe(ident, channel))
예제 #5
0
class Connection(object):
    """Server-side state and message loop for one client socket.

    ``handle`` sends an info message, waits for a valid AUTH, then
    dispatches PUBLISH / SUBSCRIBE / UNSUBSCRIBE messages to the ``do_*``
    methods, which check permissions and delegate to ``srv``.
    """

    def __init__(self, sock, addr, srv):
        self.sock = sock
        self.addr = addr
        self.srv = srv
        self.uid = None
        self.ak = None
        self.pubchans = []
        self.subchans = []
        self.active = True
        self.unpacker = Unpacker()

    def __del__(self):
        # if this message is not showing up we're leaking references
        log.debug("Connection cleanup {0}".format(self.addr))

    def write(self, data):
        """Send ``data`` on the socket; failures are logged, not raised."""
        try:
            self.sock.sendall(data)
            #  self.stats['bytes_sent'] += len(data)
        except Exception as e:
            log.critical('Exception when writing to conn', exc_info=e)

    def _feed_from_socket(self):
        """Read one chunk from the socket into the unpacker.

        Returns False when ``recv`` yields no data (the peer closed the
        connection), True otherwise.
        """
        chunk = self.sock.recv(BUFSIZ)
        if not chunk:
            return False
        self.unpacker.feed(chunk)
        return True

    def handle(self):
        """Run the connection: info handshake, authentication, dispatch.

        Returns when the peer closes the connection. Raises ``BadClient``
        when the first message is not AUTH, when authentication fails, or
        when an unknown message type is received.
        """
        # first send the info message
        self.authrand = authrand = os.urandom(4)
        self.write(msginfo(config.FBNAME, authrand))

        # Phase 1: wait for the AUTH message. A flag is needed because
        # ``break`` only leaves the inner ``for`` loop.
        authenticated = False
        while not authenticated:
            if not self._feed_from_socket():
                return
            for opcode, data in self.unpacker:
                if opcode != OP_AUTH:
                    self.error('First message was not AUTH.')
                    raise BadClient()

                ident, rhash = readauth(data)
                self.authkey_check(ident, rhash)
                authenticated = True
                break

        # Phase 2: regular traffic. Drain the unpacker before blocking on
        # the socket so messages that arrived in the same chunk as AUTH are
        # not left waiting for the next recv().
        while True:
            for opcode, data in self.unpacker:
                if opcode == OP_PUBLISH:
                    self.do_publish(*readpublish(data))
                elif opcode == OP_SUBSCRIBE:
                    self.do_subscribe(*readsubscribe(data))
                elif opcode == OP_UNSUBSCRIBE:
                    # NOTE(review): UNSUBSCRIBE is decoded with
                    # readsubscribe, as before — confirm whether a dedicated
                    # readunsubscribe should be used here.
                    self.do_unsubscribe(*readsubscribe(data))
                else:
                    self.error(
                        "Unknown message type.",
                        opcode=opcode,
                        length=len(data),
                    )
                    raise BadClient()

            if not self._feed_from_socket():
                return

    def do_publish(self, ident, chan, payload):
        """Validate and forward a publish request to the server.

        Raises ``BadClient`` when ``ident`` differs from the authenticated
        key. Publishing to a channel outside ``pubchans`` or ending in
        ``..broker`` only reports an error to the client.
        """
        if not ident == self.ak:
            self.error("Invalid authkey in message.", ident=ident)
            raise BadClient()

        if chan not in self.pubchans or chan.endswith("..broker"):
            self.error("Authkey not allowed to publish here.", chan=chan)
            return

        self.srv.do_publish(self, chan, payload)
        #  self.stats["published"] += 1

    def do_subscribe(self, ident, chan):
        """Validate and forward a subscribe request to the server.

        A trailing ``..broker`` suffix is stripped before the permission
        check against ``subchans``; the original channel name is what gets
        passed on.
        """
        checkchan = chan
        if chan.endswith('..broker'):
            checkchan = chan.rsplit('..broker', 1)[0]

        if checkchan not in self.subchans:
            self.error(
                "Authkey not allowed to subscribe here.",
                chan=chan,
            )
            return

        self.srv.do_subscribe(self, ident, chan)

    def do_unsubscribe(self, ident, chan):
        """Forward an unsubscribe request to the server."""
        # Fixed: this used to call itself with an extra argument, which
        # raised TypeError. Delegate to the server like do_subscribe does.
        self.srv.do_unsubscribe(self, ident, chan)

    def authkey_check(self, ident, rhash):
        """Authenticate ``ident`` against the server's authkey record.

        On success stores the identity, owner and channel permissions on
        the connection. Raises ``BadClient`` when the key is unknown or the
        hash does not match.
        """
        akrow = self.srv.get_authkey(ident)
        if not akrow:
            self.error("Authentication failed.", ident=ident)
            raise BadClient()

        akhash = hashsecret(self.authrand, akrow["secret"])
        if not akhash == rhash:
            self.error("Authentication failed.", ident=ident)
            raise BadClient()

        self.ak = ident
        self.uid = akrow["owner"]
        self.pubchans = akrow.get("pubchans", [])
        self.subchans = akrow.get("subchans", [])

    def forward(self, ident, chan, data):
        """Send a publish message for ``chan`` to this client."""
        self.write(msgpublish(ident, chan, data))
        #  self.stats['received'] += 1

    def error(self, msg, *args, **context):
        """Log an error, record it on the server and report it to the client.

        ``msg`` is formatted with ``args``; ``context`` is handed to
        ``srv.log_error`` unchanged.
        """
        emsg = msg.format(*args)
        log.critical(emsg)
        self.srv.log_error(emsg, self, context)
        self.write(msgerror(emsg))
예제 #6
0
 def test_unpack_1(self):
     """A complete message fed in one call is yielded as a single tuple."""
     frame = msghdr(1, b'abcdefghijklmnopqrstuvwxyz')
     unpacker = Unpacker()
     unpacker.feed(frame)
     expected = [(1, b'abcdefghijklmnopqrstuvwxyz')]
     assert list(iter(unpacker)) == expected