Exemplo n.º 1
0
class IOCore(object):

    marshal = None
    name = 'Core API'
    default_target = None

    def __init__(self, debug=False, timeout=3, do_connect=False,
                 host=None, key=None, cert=None, ca=None,
                 addr=None, fork=False, secret=None):
        '''
        Set up the client state, start the broker and the I/O loop.

        * debug -- stored as `self.debug`; also set on the marshal
          class, if the subclass defines one
        * timeout -- default response timeout used by `get()`
        * do_connect -- when true, `connect()` to `host` and
          `discover()` the default destination port right away
        * host -- URL for `connect()`; its path part is the discovery
          target unless `default_target` is set
        * key, cert, ca -- passed through to `connect()`
        * addr -- default broker address, `uuid32()` if not given
        * fork -- when true, run the broker in a separate `Process`
        * secret -- passed through to the `IOBroker` constructor
        '''
        addr = addr or uuid32()
        self._timeout = timeout
        self.default_broker = addr
        self.default_dport = 0
        self.uids = set()
        self.listeners = {}     # {nonce: Queue(), ...}
        self.callbacks = []     # [(predicate, callback, args), ...]
        self.debug = debug
        self.cid = None
        # two separate nonce pools: a small one for control commands
        # and a large one for everything else
        self.cmd_nonce = AddrPool(minaddr=0xf, maxaddr=0xfffe)
        self.nonce = AddrPool(minaddr=0xffff, maxaddr=0xffffffff)
        self.emarshal = MarshalEnv()
        self.save = None
        if self.marshal is not None:
            # self.marshal still refers to the class-level value here:
            # flag it, then shadow it on the instance with the result
            # of calling it
            self.marshal.debug = debug
            self.marshal = self.marshal()
        self.buffers = Queue.Queue()
        self._mirror = False
        self.host = host

        self.ioloop = IOLoop()

        # a pair of connected sockets: self._brs is handed to the
        # broker as its control socket, self.bridge is the local end
        self._brs, self.bridge = pairPipeSockets()
        # To fork or not to fork?
        #
        # It depends on how you are going to use RT netlink sockets
        # in your application. Performance can also differ,
        # and sometimes a forked broker can even speed up your
        # application -- keep in mind Python's GIL
        if fork:
            # Start the I/O broker in a separate process,
            # so you can use multiple RT netlink sockets in
            # one application -- it does matter, if your
            # application already uses some RT netlink
            # library and you want to smoothly try and
            # migrate to the pyroute2
            self.forked_broker = Process(target=self._start_broker,
                                         args=(fork, secret))
            self.forked_broker.start()
        else:
            # Start I/O broker as an embedded object, so
            # the RT netlink socket will be opened in the
            # same process. Technically, you can open
            # multiple RT netlink sockets within one process,
            # but only the first one will receive all the
            # answers -- so effectively only one socket per
            # process can be used.
            self.forked_broker = None
            self._start_broker(fork, secret)
        self.ioloop.start()
        # presumably makes the loop hand data arriving on the bridge
        # socket to self._route(sock, raw) -- confirm against IOLoop
        self.ioloop.register(self.bridge,
                             self._route,
                             defer=True)
        if do_connect:
            # NOTE: the path is taken from the `host` argument, so
            # do_connect requires `host` to be a URL string
            path = urlparse.urlparse(host).path
            (self.default_link,
             self.default_peer) = self.connect(self.host,
                                               key, cert, ca)
            self.default_dport = self.discover(self.default_target or path,
                                               self.default_peer)

    def _start_broker(self, fork=False, secret=None):
        '''
        Create and start the I/O broker attached to the control socket.

        With `fork` set the broker is created without a shared I/O
        loop and this call blocks until the broker's stop event is
        set. Otherwise the broker reuses `self.ioloop` and the call
        returns right after the start.
        '''
        loop = None if fork else self.ioloop
        broker = IOBroker(addr=self.default_broker,
                          ioloop=loop,
                          control=self._brs,
                          secret=secret)
        broker.start()
        if not fork:
            return
        # forked mode: keep the process alive while the broker runs
        broker._stop_event.wait()

    @debug
    def _route(self, sock, raw):
        '''
        Handle a chunk of raw data read from the bridge socket.

        The data is split into envelopes. Control responses are
        decoded and queued for the request waiting under the same
        sequence number; everything else is passed to `parse()`.

        * sock -- the socket the data came from, passed on to the
          envelope marshal
        * raw -- the received bytes
        '''
        data = io.BytesIO()
        data.length = data.write(raw)

        for envelope in self.emarshal.parse(data, sock):

            nonce = envelope['header']['sequence_number']
            flags = envelope['header']['flags']
            try:
                buf = io.BytesIO()
                buf.length = buf.write(envelope.
                                       get_attr('IPR_ATTR_CDATA'))
                buf.seek(0)
                if ((flags & NLT_CONTROL) and
                        (flags & NLT_RESPONSE)):
                    msg = mgmtmsg(buf)
                    msg.decode()
                    listener = self.listeners.get(nonce)
                    if listener is None:
                        # nobody waits for this response (e.g. the
                        # request has already timed out): drop it
                        # instead of failing with KeyError
                        continue
                    try:
                        listener.put_nowait(msg)
                    except Queue.Full:
                        # same policy as in parse(): drop on overflow
                        pass
                else:
                    self.parse(envelope, buf)
            except AttributeError:
                # now silently drop bad packet
                pass

    @debug
    def parse(self, envelope, data):
        '''
        Turn envelope payload into messages and deliver them.

        Without a marshal the payload is wrapped into one dict
        message: its `data` is the raw payload, or, if the envelope
        carries NLT_EXCEPTION, its header `error` is a RuntimeError
        built from the payload. With a marshal the payload is parsed
        by `self.marshal.parse()`.

        Every message is first offered to the registered callbacks;
        if any callback returns something other than None, the
        message is not enqueued. Otherwise it goes to the listener
        queue registered for its sequence number, or to queue 0 when
        there is no such listener (and is dropped if queue 0 does
        not exist either).

        * envelope -- the decoded transport envelope
        * data -- file-like object with the envelope payload
        '''

        if self.marshal is None:
            nonce = envelope['header']['sequence_number']
            if envelope['header']['flags'] & NLT_EXCEPTION:
                error = RuntimeError(data.getvalue())
                msgs = [{'header': {'sequence_number': nonce,
                                    'type': 0,
                                    'flags': 0,
                                    'error': error},
                         'data': None}]
            else:
                msgs = [{'header': {'sequence_number': nonce,
                                    'type': 0,
                                    'flags': 0,
                                    'error': None},
                         'data': data.getvalue()}]
        else:
            msgs = self.marshal.parse(data)

        for msg in msgs:
            try:
                key = msg['header']['sequence_number']
            except (TypeError, KeyError):
                key = 0

            # 8<--------------------------------------------------------------
            # message filtering
            # right now it is simply iterating callback list
            # .. _ioc-callbacks:
            skip = False

            for cr in self.callbacks:
                if cr[0](envelope, msg):
                    if cr[1](envelope, msg, *cr[2]) is not None:
                        skip = True

            if skip:
                continue

            # 8<--------------------------------------------------------------
            if key not in self.listeners:
                key = 0

            # mirror only when queue 0 really exists; otherwise there
            # is nowhere to put the copy
            if self._mirror and (key != 0) and (0 in self.listeners):
                # On Python 2.6 it can fail due to class fabrics
                # in nlmsg definitions, so parse it again. It should
                # not be much slower than copy.deepcopy()
                if getattr(msg, 'raw', None) is not None:
                    raw = io.BytesIO()
                    raw.length = raw.write(msg.raw)
                    new = self.marshal.parse(raw)[0]
                else:
                    new = copy.deepcopy(msg)
                try:
                    self.listeners[0].put_nowait(new)
                except Queue.Full:
                    # the mirror is best-effort: never lose the
                    # original message because the copy did not fit
                    pass

            if key in self.listeners:
                try:
                    self.listeners[key].put_nowait(msg)
                except Queue.Full:
                    # FIXME: log this
                    pass

    def command(self, cmd, attrs=None, expect=None, addr=None):
        '''
        Send a control command to a broker and wait for the reply.

        * cmd -- command code
        * attrs -- list of [name, value] attributes; empty if omitted
        * expect -- attribute name, or list/tuple of names, to
          extract from the reply
        * addr -- broker address, `self.default_broker` if not set

        Returns None when `expect` is None, a single attribute value
        when `expect` is one name, or a list of values when it is a
        list or tuple. Raises RuntimeError with the reply's
        IPR_ATTR_ERROR when the reply is not IPRCMD_ACK.
        '''
        addr = addr or self.default_broker
        msg = mgmtmsg(io.BytesIO())
        msg['cmd'] = cmd
        # a fresh list per call: a shared default list would leak
        # state between commands if it is ever modified
        msg['attrs'] = [] if attrs is None else attrs
        msg['header']['type'] = NLMSG_CONTROL
        msg.encode()
        rsp = self.request(msg.buf.getvalue(),
                           env_flags=NLT_CONTROL,
                           nonce_pool=self.cmd_nonce,
                           addr=addr)[0]
        if rsp['cmd'] != IPRCMD_ACK:
            raise RuntimeError(rsp.get_attr('IPR_ATTR_ERROR'))
        if expect is None:
            return None
        if isinstance(expect, (list, tuple)):
            return [rsp.get_attr(item) for item in expect]
        return rsp.get_attr(expect)

    def unregister(self, addr=None):
        '''
        Send IPRCMD_UNREGISTER to the broker at `addr` (the default
        broker if not set).
        '''
        reply = self.command(IPRCMD_UNREGISTER, addr=addr)
        return reply

    def register(self, secret, addr=None):
        '''
        Send IPRCMD_REGISTER with `secret` to the broker at `addr`
        (the default broker if not set).
        '''
        attrs = [['IPR_ATTR_SECRET', secret]]
        return self.command(IPRCMD_REGISTER, attrs, addr=addr)

    def discover(self, url, addr=None):
        '''
        Send IPRCMD_DISCOVER for `url` to the broker at `addr` (the
        default broker if not set) and return the IPR_ATTR_ADDR
        value from the reply.
        '''
        # .. _ioc-discover:
        attrs = [['IPR_ATTR_HOST', url]]
        return self.command(IPRCMD_DISCOVER,
                            attrs,
                            expect='IPR_ATTR_ADDR',
                            addr=addr)

    def provide(self, url):
        '''
        Send IPRCMD_PROVIDE and then IPRCMD_CONNECT for `url` to the
        default broker; the result of the second command is returned.
        '''
        for cmd in (IPRCMD_PROVIDE, IPRCMD_CONNECT):
            # a new attribute list for every command
            reply = self.command(cmd, [['IPR_ATTR_HOST', url]])
        return reply

    def remove(self, url):
        '''
        Send IPRCMD_REMOVE for `url` to the default broker.
        '''
        attrs = [['IPR_ATTR_HOST', url]]
        return self.command(IPRCMD_REMOVE, attrs)

    def serve(self, url, key='', cert='', ca='', addr=None):
        '''
        Send IPRCMD_SERVE for `url`, together with the SSL key,
        certificate and CA values, to the broker at `addr` (the
        default broker if not set).
        '''
        attrs = [['IPR_ATTR_HOST', url],
                 ['IPR_ATTR_SSL_KEY', key],
                 ['IPR_ATTR_SSL_CERT', cert],
                 ['IPR_ATTR_SSL_CA', ca]]
        return self.command(IPRCMD_SERVE, attrs, addr=addr)

    def shutdown(self, url, addr=None):
        '''
        Send IPRCMD_SHUTDOWN for `url` to the broker at `addr` (the
        default broker if not set).
        '''
        attrs = [['IPR_ATTR_HOST', url]]
        return self.command(IPRCMD_SHUTDOWN, attrs, addr=addr)

    def connect(self, host=None, key='', cert='', ca='', addr=None):
        '''
        Send IPRCMD_CONNECT for `host` (`self.host` if not given) to
        the broker at `addr` (the default broker if not set).

        The (uid, addr) pair is remembered in `self.uids`; the
        method returns the (uid, peer) values taken from the
        IPR_ATTR_UUID and IPR_ATTR_ADDR attributes of the reply.
        '''
        target = host or self.host
        attrs = [['IPR_ATTR_HOST', target],
                 ['IPR_ATTR_SSL_KEY', key],
                 ['IPR_ATTR_SSL_CERT', cert],
                 ['IPR_ATTR_SSL_CA', ca]]
        wanted = ['IPR_ATTR_UUID', 'IPR_ATTR_ADDR']
        uid, peer = self.command(IPRCMD_CONNECT,
                                 attrs,
                                 expect=wanted,
                                 addr=addr)
        # `addr` is stored exactly as passed, possibly None
        self.uids.add((uid, addr))
        return uid, peer

    def disconnect(self, uid, addr=None):
        '''
        Send IPRCMD_DISCONNECT for the link `uid` to the broker at
        `addr` (the default broker if not set) and forget the
        (uid, addr) pair. Returns the command result.
        '''
        attrs = [['IPR_ATTR_UUID', uid]]
        reply = self.command(IPRCMD_DISCONNECT, attrs, addr=addr)
        # KeyError here means the pair was never stored by connect()
        self.uids.remove((uid, addr))
        return reply

    def release(self):
        '''
        Shutdown all threads and release netlink sockets.

        Disconnects every link recorded in `self.uids`, sends
        IPRCMD_STOP to the default broker, waits for the forked
        broker process (if any), stops the I/O loop and closes both
        ends of the control pipe.
        '''
        # iterate over a snapshot: disconnect() removes entries
        # from self.uids
        for (uid, addr) in tuple(self.uids):
            try:
                self.disconnect(uid, addr=addr)
            except Queue.Empty as e:
                # a reply timeout is tolerated unless the link was
                # recorded with the default broker address
                if addr == self.default_broker:
                    raise e
        self.command(IPRCMD_STOP)
        if self.forked_broker:
            self.forked_broker.join()
        self.ioloop.shutdown()
        self.ioloop.join()

        # NOTE(review): one unsigned int with the value 4 is written
        # to the broker side of the pipe before closing; the purpose
        # is not visible here -- confirm against the pipe socket code
        self._brs.send(struct.pack('I', 4))
        self._brs.close()
        self.bridge.close()

    def mirror(self, operate=True):
        '''
        Switch message mirroring on (`operate` true) or off.

        While mirroring is on, every message delivered to a
        non-zero listener queue is also copied into the default
        queue 0.
        '''
        self._mirror = operate

    def monitor(self, operate=True):
        '''
        Create/destroy the default 0 queue. Netlink socket
        receives messages all the time, and there are many
        messages that are not replies. They are just
        generated by the kernel as a reflection of settings
        changes. To start receiving these messages, call
        Netlink.monitor(). They can be fetched by
        Netlink.get(0) or just Netlink.get().

        * operate -- true to subscribe, false to unsubscribe

        The call is idempotent: asking to start monitoring while
        already subscribed, or to stop it while not subscribed,
        does nothing.
        '''
        if operate:
            if self.cid is not None:
                # already subscribed; must not fall through to the
                # unsubscribe branch
                return
            self.listeners[0] = Queue.Queue(maxsize=_QUEUE_MAXSIZE)
            self.cid = self.command(IPRCMD_SUBSCRIBE,
                                    [['IPR_ATTR_KEY', {'offset': 8,
                                                       'key': 0,
                                                       'mask': 0}]],
                                    expect='IPR_ATTR_CID')
        elif self.cid is not None:
            self.command(IPRCMD_UNSUBSCRIBE,
                         [['IPR_ATTR_CID', self.cid]])
            self.cid = None
            del self.listeners[0]

    def register_callback(self, callback,
                          predicate=lambda e, x: True, args=None):
        '''
        Register a callback to run on a message arrival.

        For every incoming message the predicate is called as
        `predicate(envelope, msg)`; when it returns a true value,
        the callback is called as `callback(envelope, msg, *args)`.
        The default predicate accepts every message. Args is a list
        or tuple of extra arguments, empty by default. If a callback
        returns anything but None, the message is not put into any
        queue.

        Simplest example, assume ipr is the IPRoute() instance::

            # create a simplest callback that will print messages
            def cb(env, msg):
                print(msg)

            # register callback for any message:
            ipr.register_callback(cb)

        More complex example, with filtering::

            # Set object's attribute after the message key
            def cb(env, msg, obj):
                obj.some_attr = msg["some key"]

            # Register the callback only for the loopback device, index 1:
            ipr.register_callback(cb,
                                  lambda e, x: x.get('index', None) == 1,
                                  (self, ))

        Please note: you do **not** need to register the default 0 queue
        to invoke callbacks on broadcast messages. Callbacks are
        iterated **before** messages get enqueued.
        '''
        extra = [] if args is None else args
        record = (predicate, callback, extra)
        self.callbacks.append(record)

    def unregister_callback(self, callback):
        '''
        Drop the first registered record whose callback equals
        `callback`; do nothing if there is no such record.
        '''
        for position, record in enumerate(self.callbacks):
            if record[1] == callback:
                del self.callbacks[position]
                return

    @debug
    def get(self, key=0, raw=False, timeout=None, terminate=None,
            nonce_pool=None):
        '''
        Get messages from a listener queue.

        * key -- message queue number; 0 is the default (broadcast)
          queue
        * raw -- not used by this implementation
        * timeout -- seconds to wait for each message, defaults to
          the instance timeout; ignored for queue 0
        * terminate -- optional callable; when `terminate(msg)` is
          true, reading stops and that message is not returned
        * nonce_pool -- pool to return `key` to, `self.nonce` by
          default

        Queue 0 is polled until one message arrives, and that
        message is returned. For any other queue messages are
        collected until NLMSG_DONE, a message without NLM_F_MULTI,
        or `terminate` stops the loop; the queue is then removed
        and `key` is freed.

        Returns the list of collected messages (for RPC, i.e.
        without a marshal, the `data` items of the messages).
        Raises Queue.Empty on timeout, or the exception stored in a
        message header `error` field.
        '''
        nonce_pool = nonce_pool or self.nonce
        queue = self.listeners[key]
        result = []
        e = None
        timeout = (timeout or self._timeout) if (key != 0) else 0xffff
        while True:
            # timeout should also be set to catch ctrl-c
            # Bug-Url: http://bugs.python.org/issue1360
            try:
                msg = queue.get(block=True, timeout=timeout)
            except Queue.Empty as x:
                if key == 0:
                    # queue 0 is polled forever: a timeout is not an
                    # error here and must not be remembered, or it
                    # would be raised after the next successful read
                    continue
                e = x
                break

            if (terminate is not None) and terminate(msg):
                break

            # exceptions
            if msg['header'].get('error', None) is not None:
                e = msg['header']['error']

            # RPC
            if self.marshal is None:
                data = msg.get('data', msg)
            else:
                data = msg

            # Netlink
            if (msg['header']['type'] != NLMSG_DONE):
                result.append(data)

            # break the loop if any
            if (key == 0) or (e is not None):
                break

            # wait for NLMSG_DONE if NLM_F_MULTI
            if (terminate is None) and (
                    (msg['header']['type'] == NLMSG_DONE) or
                    (not msg['header']['flags'] & NLM_F_MULTI)):
                break

        if key != 0:
            # delete the queue
            del self.listeners[key]
            nonce_pool.free(key)
            # get remaining messages from the queue and
            # re-route them to queue 0 or drop
            while not queue.empty():
                msg = queue.get()
                if 0 in self.listeners:
                    self.listeners[0].put(msg)

        if e is not None:
            raise e

        return result

    @debug
    def push(self, host, msg,
             env_flags=None,
             nonce=0,
             cname=None):
        '''
        Wrap `msg` into a transport envelope and send it via the
        bridge socket.

        * host -- (address, port) pair of the destination
        * msg -- payload, stored as the IPR_ATTR_CDATA attribute
        * env_flags -- envelope header flags, left untouched if None
        * nonce -- envelope sequence number
        * cname -- optional value for the IPR_ATTR_CNAME attribute
        '''
        dst_addr, dst_port = host
        env = envmsg()
        env['header']['sequence_number'] = nonce
        env['header']['pid'] = os.getpid()
        env['header']['type'] = NLMSG_TRANSPORT
        if env_flags is not None:
            env['header']['flags'] = env_flags
        env['dst'] = dst_addr
        env['src'] = self.default_broker
        env['dport'] = dst_port
        env['ttl'] = 16
        # a random id is generated for every envelope
        env['id'] = uuid.uuid4().bytes
        env['attrs'] = [['IPR_ATTR_CDATA', msg]]
        if cname is not None:
            env['attrs'].append(['IPR_ATTR_CNAME', cname])
        env.encode()
        encoded = env.buf.getvalue()
        self.bridge.send(encoded)

    def request(self, msg,
                env_flags=0,
                addr=None,
                port=None,
                nonce=None,
                nonce_pool=None,
                cname=None,
                response_timeout=None,
                terminate=None):
        '''
        Send `msg` and return the response collected by `get()`.

        A listener queue is created under the request nonce (taken
        from `nonce`, or allocated from `nonce_pool` / `self.nonce`),
        the message is pushed to (`addr`, `port`) -- by default the
        default broker and the default destination port -- and the
        reply is read with `response_timeout` and `terminate`
        passed through to `get()`.
        '''
        pool = nonce_pool or self.nonce
        seq = nonce or pool.alloc()
        dport = port or self.default_dport
        broker = addr or self.default_broker
        # the queue must exist before the message leaves
        self.listeners[seq] = Queue.Queue(maxsize=_QUEUE_MAXSIZE)
        self.push((broker, dport), msg, env_flags, seq, cname)
        response = self.get(seq,
                            nonce_pool=pool,
                            timeout=response_timeout,
                            terminate=terminate)
        return response
Exemplo n.º 2
0
class IOBroker(object):
    def __init__(self,
                 addr=0x01000000,
                 broadcast=0xffffffff,
                 ioloop=None,
                 control=None,
                 secret=None):
        '''
        Initialize the broker state; nothing is started here, see
        `start()`.

        * addr -- the broker's own address; envelopes with this
          destination are handled locally by `route()`
        * broadcast -- destination address used for the address
          announcement in `handle_connect()`
        * ioloop -- I/O loop to share; if not given, the broker
          creates its own one and runs in the standalone mode
        * control -- optional control socket; it is registered as
          a client and gets admin rights in `route_control()`
        * secret -- broker secret; random bytes if not given
        '''
        self.pid = os.getpid()
        self._stop_event = threading.Event()
        self._reload_event = threading.Event()
        self.shutdown_flag = threading.Event()
        self.addr = addr
        self.broadcast = broadcast
        self.marshal = MarshalEnv()
        self.ports = AddrPool(minaddr=0xff)
        self.nonces = AddrPool(minaddr=0xfff)
        self.active_sys = {}
        self.local = {}           # {port: Link()...}, see register_link()
        self.links = {}           # {uid: Link()...}
        self.remote = {}          # {uid: Link()...}, remote links only
        self.discover = {}
        # fd lists for select()
        self._rlist = set()
        self._wlist = set()
        self._xlist = set()
        # routing
        self.masquerade = {}      # {int: MasqRecord()...}
        self.packet_ids = {}      # {int: CacheRecord()...}
        self.clients = set()      # set(socket, socket...)
        self.servers = set()      # set(socket, socket...)
        self.controls = set()     # set(socket, socket...)
        self.sockets = {}
        self.subscribe = {}
        self.providers = {}
        # modules = { IPRCMD_STOP: {'access': access.ADMIN,
        #                           'command': <function>},
        #             IPRCMD_...: ...
        self.modules = dict(((x.target, {'access': x.level,
                                         'command': x.command})
                             for x in modules))
        # NOTE(review): a list of 1024 integers, presumably the pool
        # of subscription ids used by the command modules -- confirm
        self._cid = list(range(1024))
        # secret; write non-zero byte as terminator
        if secret:
            self.secret = secret
        else:
            self.secret = os.urandom(15)
            self.secret += b'\xff'
        self.uuid = uuid.uuid4()
        # masquerade cache expiration
        self._expire_thread = Cache(self._stop_event)
        # the nonce of a masquerade record is given back to the
        # pool via the callback passed with the map
        self._expire_thread.register_map(self.masquerade, self.nonces.free)
        self._expire_thread.register_map(self.packet_ids)
        self._expire_thread.setDaemon(True)
        if ioloop:
            self.ioloop = ioloop
            self.standalone = False
        else:
            self.ioloop = IOLoop()
            self.standalone = True
        if control:
            self.add_client(control)
            self.controls.add(control)
            self.ioloop.register(control, self.route, defer=True)

    def handle_connect(self, fd, event):
        '''
        Accept a connection on the listening socket `fd`, add the
        new socket as a client, send it a control ACK carrying the
        broker address and register it in the I/O loop.

        `event` is not used.
        '''
        client, _peer = fd.accept()
        self.add_client(client)
        # announce address
        # .. _ioc-connect:
        ack = mgmtmsg()
        ack['header']['type'] = NLMSG_CONTROL
        ack['cmd'] = IPRCMD_ACK
        ack['attrs'] = [['IPR_ATTR_ADDR', self.addr]]
        ack.encode()
        # wrap the ACK into a transport envelope
        announce = envmsg()
        announce['dst'] = self.broadcast
        announce['id'] = uuid.uuid4().bytes
        announce['header']['pid'] = os.getpid()
        announce['header']['type'] = NLMSG_TRANSPORT
        announce['header']['flags'] = NLT_CONTROL | NLT_RESPONSE
        announce['attrs'] = [['IPR_ATTR_CDATA',
                              ack.buf.getvalue()]]
        announce.encode()
        client.send(announce.buf.getvalue())
        self.ioloop.register(client, self.route, defer=True)

    def alloc_addr(self):
        return self.ports.alloc()

    def dealloc_addr(self, addr):
        self.ports.free(addr)

    def route_control(self, sock, envelope):
        '''
        Run a control command addressed to this broker and send the
        reply back via `sock`.

        * sock -- the socket the command came from; it defines the
          access rights of the caller
        * envelope -- decoded envelope with the command in the
          IPR_ATTR_CDATA attribute

        The reply is IPRCMD_ACK if the command handler has returned
        without an exception; otherwise it is IPRCMD_ERR with the
        formatted traceback in IPR_ATTR_ERROR.
        '''
        pid = envelope['header']['pid']
        nonce = envelope['header']['sequence_number']
        # src = envelope['src']
        dst = envelope['dst']
        sport = envelope['sport']
        dport = envelope['dport']
        data = io.BytesIO(envelope.get_attr('IPR_ATTR_CDATA'))
        cmd = self.parse_control(data)
        module = cmd['cmd']
        # the response is an error until the command succeeds
        rsp = mgmtmsg()
        rsp['header']['type'] = NLMSG_CONTROL
        rsp['header']['sequence_number'] = nonce
        rsp['cmd'] = IPRCMD_ERR
        rsp['attrs'] = []

        # control sockets get admin rights, ordinary clients user
        # rights, any other socket no rights at all
        rights = 0
        if sock in self.controls:
            rights = access.ADMIN
        elif sock in self.clients:
            rights = access.USER
        try:
            # an unknown command code raises KeyError here and is
            # reported to the caller like any other failure
            if rights & self.modules[module]['access']:
                self.modules[module]['command'](self, sock, envelope,
                                                cmd, rsp)
            else:
                raise IOError(13, 'Permission denied')

            rsp['cmd'] = IPRCMD_ACK
            rsp['attrs'].append(['IPR_ATTR_SOURCE', cmd['cmd']])
        except Exception:
            # replace whatever the handler has put into the response
            rsp['attrs'] = [['IPR_ATTR_ERROR', traceback.format_exc()]]

        rsp.encode()
        # the reply envelope keeps the sequence number and pid of the
        # request and has the ports swapped
        ne = envmsg()
        ne['header']['sequence_number'] = nonce
        ne['header']['pid'] = pid
        ne['header']['type'] = NLMSG_TRANSPORT
        ne['header']['flags'] = NLT_CONTROL | NLT_RESPONSE
        ne['src'] = dst
        ne['ttl'] = 16
        ne['id'] = uuid.uuid4().bytes
        ne['dport'] = sport
        ne['sport'] = dport
        ne['attrs'] = [['IPR_ATTR_CDATA', rsp.buf.getvalue()]]
        ne.encode()
        sock.send(ne.buf.getvalue())

        # the flag is checked only after the reply is sent, so the
        # caller gets the answer before the broker goes down
        # (presumably it is set by a command handler -- confirm)
        if self.shutdown_flag.is_set():
            self.shutdown()

    def route_forward(self, sock, envelope):
        nonce = envelope['header']['sequence_number']

        envelope['ttl'] -= 1
        if envelope['ttl'] <= 0:
            return

        if (envelope['dst'] == 0) and (nonce in self.masquerade):
            return self.unmasq(nonce, envelope)
        else:
            flags = envelope['header']['flags']
            for (uid, link) in self.remote.items():
                # by default, send packets only via SOCK_STREAM,
                # and use SOCK_DGRAM only upon request

                # skip STREAM sockets if NLT_DGRAM is requested
                if ((link.sock.type == socket.SOCK_STREAM) and
                        (flags & NLT_DGRAM)):
                    continue

                # skip DGRAM sockets if NLT_DGRAM is not requested
                if ((link.sock.type == socket.SOCK_DGRAM) and
                        not (flags & NLT_DGRAM)):
                    continue

                # in any other case -- send packet
                self.remote[uid].gate(envelope, sock)

    def unmasq(self, nonce, envelope):
        target = self.masquerade[nonce]
        envelope['header']['sequence_number'] = \
            target.envelope['header']['sequence_number']
        envelope['header']['pid'] = \
            target.envelope['header']['pid']
        envelope.reset()
        envelope.encode()
        target.socket.send(envelope.buf.getvalue())

    def route_data(self, sock, envelope):
        nonce = envelope['header']['sequence_number']

        if envelope['dport'] in self.local:
            try:
                self.local[envelope['dport']].gate(envelope, sock)
            except:
                traceback.print_exc()

        elif nonce in self.masquerade:
            self.unmasq(nonce, envelope)

        else:
            # FIXME fix it, please, or kill with fire
            # there should be no data repack
            data = io.BytesIO(envelope.get_attr('IPR_ATTR_CDATA'))
            for cid, u32 in self.subscribe.items():
                self.filter_u32(u32, data)

    def filter_u32(self, u32, data):
        for offset, key, mask in u32['keys']:
            data.seek(offset)
            compare = struct.unpack('I', data.read(4))[0]
            if compare & mask != key:
                return
        # envelope data
        envelope = envmsg()
        envelope['header']['type'] = NLMSG_TRANSPORT
        envelope['attrs'] = [['IPR_ATTR_CDATA',
                              data.getvalue()]]
        envelope['id'] = uuid.uuid4().bytes
        envelope.encode()
        u32['socket'].send(envelope.buf.getvalue())

    def route_netlink(self, sock, raw):
        '''
        Route raw netlink data.

        * sock -- the socket the data came from (not used)
        * raw -- received bytes, one or more netlink messages

        The sequence number of the first message (unsigned int at
        offset 8) is looked up in the masquerade table. Without a
        record the data is offered to all the subscriptions. With
        a record, the sequence number and pid of every message in
        the buffer are rewritten with the values stored in the
        record, and the data is sent to the record's socket in a
        transport envelope.
        '''
        data = io.BytesIO()
        data.length = data.write(raw)
        data.seek(8)
        seq = struct.unpack('I', data.read(4))[0]

        # extract masq info
        target = self.masquerade.get(seq, None)
        if target is None:
            for u32 in self.subscribe.values():
                self.filter_u32(u32, data)
        else:
            # length, type, flags, sequence number, pid
            header_size = struct.calcsize('IHHII')
            offset = 0
            while offset < data.length:
                data.seek(offset)
                header = data.read(header_size)
                if len(header) < header_size:
                    # truncated trailing header
                    break
                length = struct.unpack('I', header[:4])[0]
                if length < header_size:
                    # a message can not be shorter than its header;
                    # stop here, or a zero length would keep the
                    # loop at the same offset forever
                    break
                data.seek(offset + 8)
                data.write(struct.pack('II',
                                       target.data.nonce,
                                       target.data.pid))
                # skip to the next in chunk
                offset += length
            # envelope data
            envelope = envmsg()
            envelope['header']['sequence_number'] = \
                target.envelope['header']['sequence_number']
            envelope['header']['pid'] = \
                target.envelope['header']['pid']
            envelope['header']['type'] = NLMSG_TRANSPORT
            # envelope['dst'] = target.envelope['src']
            envelope['src'] = target.envelope['dst']
            envelope['ttl'] = 16
            envelope['id'] = uuid.uuid4().bytes
            envelope['dport'] = target.envelope['sport']
            envelope['sport'] = target.envelope['dport']
            envelope['attrs'] = [['IPR_ATTR_CDATA',
                                  data.getvalue()]]
            envelope.encode()
            # target
            target.socket.send(envelope.buf.getvalue())

    def route(self, sock, raw):
        """
        Route message
        """
        data = io.BytesIO()
        data.length = data.write(raw)

        if data.length == 0 and self.ioloop.unregister(sock):
            if sock in self.clients:
                self.remove_client(sock)
            else:
                self.deregister_link(fd=sock)
            return

        for envelope in self.marshal.parse(data, sock):
            if envelope['id'] in self.packet_ids:
                # drop duplicated packets
                continue
            else:
                # register packet id
                self.packet_ids[envelope['id']] = CacheRecord(None)

            if envelope['dst'] != self.addr:
                # FORWARD
                # a packet for a remote system
                self.route_forward(sock, envelope)
            else:
                # INPUT
                # a packet for a local system
                if ((envelope['header']['flags'] & NLT_CONTROL) and not
                        (envelope['header']['flags'] & NLT_RESPONSE)):
                    # control packets
                    self.route_control(sock, envelope)
                else:
                    # transport packets
                    self.route_data(sock, envelope)

    def gate_forward(self, envelope, sock):
        '''
        Masquerade an envelope and return it encoded.

        A copy of the envelope is saved, together with `sock`, in
        the masquerade table under a new nonce; the envelope itself
        gets this nonce as the sequence number and the pid of the
        current process.
        '''
        nonce = self.nonces.alloc()
        record = MasqRecord(sock)
        # save a copy, since the envelope is modified below
        record.add_envelope(envelope.copy())
        self.masquerade[nonce] = record
        header = envelope['header']
        header['sequence_number'] = nonce
        header['pid'] = os.getpid()
        envelope.buf.seek(0)
        envelope.encode()
        return envelope.buf.getvalue()

    def gate_untag(self, envelope, sock):
        '''
        Extract the payload of an envelope and masquerade it.

        A copy of the envelope and the payload buffer are saved,
        together with `sock`, in the masquerade table under a new
        nonce. Two unsigned ints at offset 8 of the payload are
        overwritten with the nonce and the broker pid, and the
        resulting bytes are returned.
        '''
        payload = io.BytesIO(envelope.get_attr('IPR_ATTR_CDATA'))
        nonce = self.nonces.alloc()
        record = MasqRecord(sock)
        record.add_envelope(envelope.copy())
        record.add_data(payload)
        self.masquerade[nonce] = record
        payload.seek(8)
        payload.write(struct.pack('II', nonce, self.pid))
        return payload.getvalue()

    def parse_control(self, data):
        '''
        Decode a control message from the start of `data` and
        return it.
        '''
        data.seek(0)
        message = mgmtmsg(data)
        message.decode()
        return message

    def register_link(self, uid, port, sock,
                      established=False, remote=False):
        '''
        Create a Link and register it.

        The link is always stored in `self.links` by `uid`, and
        additionally in `self.remote` by `uid` (remote links) or in
        `self.local` by `port` (local links). The socket of a link
        that is not established yet is added to the read list.
        Returns the new link.
        '''
        if not established:
            self._rlist.add(sock)

        link = Link(uid, port, sock, established, remote)
        self.links[uid] = link
        registry, key = (self.remote, uid) if remote else (self.local, port)
        registry[key] = link
        return link

    def deregister_link(self, uid=None, fd=None):
        if fd is not None:
            for (uid, link) in self.links.items():
                if link.sock == fd:
                    break

        link = self.links[uid]

        if not link.keep:
            link.sock.close()
            self._rlist.remove(link.sock)

        del self.links[link.uid]
        if link.remote:
            del self.remote[link.uid]
        else:
            del self.local[link.port]
        return link.sock

    def add_client(self, sock):
        '''
        Add a client connection. Should not be called
        manually, but only on a client connect.
        '''
        self._rlist.add(sock)
        self._wlist.add(sock)
        self.clients.add(sock)
        return sock

    def remove_client(self, sock):
        self._rlist.remove(sock)
        self._wlist.remove(sock)
        self.clients.remove(sock)
        sock.close()
        return sock

    def start(self):
        self._expire_thread.start()
        if self.standalone:
            self.ioloop.start()

    def shutdown(self):
        self._stop_event.set()
        for sock in self.servers:
            sock.close()
        # shutdown sequence
        self._expire_thread.join()
        if self.standalone:
            self.ioloop.shutdown()
            self.ioloop.join()
Exemplo n.º 3
0
class IOCore(object):
    '''
    Client-side core of the messaging API.

    An instance talks to an I/O broker through a socket pair,
    wraps outgoing payloads into envelopes and routes incoming
    messages into per-nonce queues kept in `self.listeners`.
    '''

    # Marshal class; __init__ replaces it with an instance when it
    # is not None. With no marshal, parse() wraps raw payloads into
    # plain dicts instead of decoding them.
    marshal = None
    # Human-readable API name (not referenced by the code visible here).
    name = 'Core API'
    # Target passed to discover() on connect; when not set, the path
    # of the host URL is used instead.
    default_target = None

    def __init__(self,
                 debug=False,
                 timeout=3,
                 do_connect=False,
                 host=None,
                 key=None,
                 cert=None,
                 ca=None,
                 addr=None,
                 fork=False,
                 secret=None):
        '''
        Set up the local state, start the I/O broker and the
        I/O loop, and optionally connect to a host.

        * debug -- stored as `self.debug` and set on the marshal
          class, if there is one
        * timeout -- default timeout, in seconds, used by get()
          for all the queues except the queue 0
        * do_connect -- connect to `host` and discover the default
          destination port right away
        * host -- URL of the host to connect to
        * key, cert, ca -- forwarded to connect() when `do_connect`
          is set
        * addr -- address of the broker; generated with uuid32()
          when not given
        * fork -- run the broker in a separate process instead of
          embedding it
        * secret -- passed to the broker as is
        '''
        addr = addr or uuid32()
        self._timeout = timeout
        self.default_broker = addr
        self.default_dport = 0
        self.uids = set()
        self.listeners = {}  # {nonce: Queue(), ...}
        self.callbacks = []  # [(predicate, callback, args), ...]
        self.debug = debug
        self.cid = None
        # two separate nonce pools: one for control commands,
        # one for all the other requests
        self.cmd_nonce = AddrPool(minaddr=0xf, maxaddr=0xfffe)
        self.nonce = AddrPool(minaddr=0xffff, maxaddr=0xffffffff)
        self.emarshal = MarshalEnv()
        self.save = None
        if self.marshal is not None:
            # at this point self.marshal is still the class-level
            # attribute, so the debug flag is set on the marshal
            # class before it gets instantiated
            self.marshal.debug = debug
            self.marshal = self.marshal()
        self.buffers = Queue.Queue()
        self._mirror = False
        self.host = host

        self.ioloop = IOLoop()

        # socket pair: one end is handed to the broker as its
        # control channel, the other one is used by this object
        self._brs, self.bridge = pairPipeSockets()
        # To fork or not to fork?
        #
        # It depends on how you gonna use RT netlink sockets
        # in your application. Performance can also differ,
        # and sometimes forked broker can even speedup your
        # application -- keep in mind Python's GIL
        if fork:
            # Start the I/O broker in a separate process,
            # so you can use multiple RT netlink sockets in
            # one application -- it does matter, if your
            # application already uses some RT netlink
            # library and you want to smoothly try and
            # migrate to the pyroute2
            self.forked_broker = Process(target=self._start_broker,
                                         args=(fork, secret))
            self.forked_broker.start()
        else:
            # Start I/O broker as an embedded object, so
            # the RT netlink socket will be opened in the
            # same process. Technically, you can open
            # multiple RT netlink sockets within one process,
            # but only the first one will receive all the
            # answers -- so effectively only one socket per
            # process can be used.
            self.forked_broker = None
            self._start_broker(fork, secret)
        self.ioloop.start()
        # everything received on the bridge socket goes to _route()
        self.ioloop.register(self.bridge, self._route, defer=True)
        if do_connect:
            # the path of the host URL is the discovery target,
            # unless the class defines default_target
            path = urlparse.urlparse(host).path
            (self.default_link,
             self.default_peer) = self.connect(self.host, key, cert, ca)
            self.default_dport = self.discover(self.default_target or path,
                                               self.default_peer)

    def _start_broker(self, fork=False, secret=None):
        '''
        Create and start the I/O broker.

        An embedded broker shares the I/O loop of this object;
        a forked one gets no loop passed and this call blocks
        until the broker's stop event is set.
        '''
        if fork:
            loop = None
        else:
            loop = self.ioloop
        broker = IOBroker(addr=self.default_broker,
                          ioloop=loop,
                          control=self._brs,
                          secret=secret)
        broker.start()
        if not fork:
            return
        # forked: keep the child process alive until the stop
        broker._stop_event.wait()

    @debug
    def _route(self, sock, raw):
        '''
        Dispatch raw data received from the broker.

        * sock -- the socket the data was received on
        * raw -- the received bytes

        The data is parsed into envelopes. A control response
        (both NLT_CONTROL and NLT_RESPONSE flags set) is decoded
        as a management message and put into the listener queue
        of the envelope's sequence number; everything else is
        passed to parse().

        A control response with no listener (e.g. get() has
        already timed out and removed the queue) or with a full
        listener queue is dropped, so one late packet can not
        break the processing of the remaining envelopes.
        '''
        data = io.BytesIO()
        data.length = data.write(raw)

        for envelope in self.emarshal.parse(data, sock):

            nonce = envelope['header']['sequence_number']
            flags = envelope['header']['flags']
            try:
                buf = io.BytesIO()
                buf.length = buf.write(envelope.get_attr('IPR_ATTR_CDATA'))
                buf.seek(0)
                if ((flags & NLT_CONTROL) and (flags & NLT_RESPONSE)):
                    msg = mgmtmsg(buf)
                    msg.decode()
                    # the listener can be gone by now: get() removes
                    # the queue as soon as it returns or times out
                    listener = self.listeners.get(nonce)
                    if listener is not None:
                        try:
                            listener.put_nowait(msg)
                        except Queue.Full:
                            # same policy as in parse(): drop
                            pass
                else:
                    self.parse(envelope, buf)
            except AttributeError:
                # now silently drop bad packet
                pass

    @debug
    def parse(self, envelope, data):
        '''
        Decode the payload of an envelope and deliver the
        resulting messages.

        * envelope -- the envelope the payload arrived in
        * data -- a buffer with the payload

        Without a marshal the payload is wrapped into one plain
        dict message; if the envelope carries the NLT_EXCEPTION
        flag, the payload becomes a RuntimeError stored in the
        message header. With a marshal, the marshal parses the
        buffer.

        Every message is first offered to the registered
        callbacks; if any matching callback returns not None,
        the message is not enqueued. Otherwise it goes to the
        listener queue of its sequence number, or to the queue 0
        if there is no such listener. With mirroring on, a copy
        of every message addressed to a non-zero queue is also
        put into the queue 0 -- if that queue exists.
        '''

        if self.marshal is None:
            nonce = envelope['header']['sequence_number']
            if envelope['header']['flags'] & NLT_EXCEPTION:
                error = RuntimeError(data.getvalue())
                payload = None
            else:
                error = None
                payload = data.getvalue()
            msgs = [{
                'header': {
                    'sequence_number': nonce,
                    'type': 0,
                    'flags': 0,
                    'error': error
                },
                'data': payload
            }]
        else:
            msgs = self.marshal.parse(data)

        for msg in msgs:
            try:
                key = msg['header']['sequence_number']
            except (TypeError, KeyError):
                key = 0

            # 8<--------------------------------------------------------------
            # message filtering
            # right now it is simply iterating callback list
            # .. _ioc-callbacks:
            skip = False

            for cr in self.callbacks:
                if cr[0](envelope, msg):
                    if cr[1](envelope, msg, *cr[2]) is not None:
                        skip = True

            if skip:
                continue

            # 8<--------------------------------------------------------------
            if key not in self.listeners:
                key = 0

            # The queue 0 exists only while monitoring is on; with
            # no queue there is nowhere to mirror to, and a failed
            # mirror must not prevent the delivery of the original.
            mirror_queue = self.listeners.get(0)
            if self._mirror and (key != 0) and (mirror_queue is not None):
                # On Python 2.6 it can fail due to class fabrics
                # in nlmsg definitions, so parse it again. It should
                # not be much slower than copy.deepcopy()
                if getattr(msg, 'raw', None) is not None:
                    raw = io.BytesIO()
                    raw.length = raw.write(msg.raw)
                    new = self.marshal.parse(raw)[0]
                else:
                    new = copy.deepcopy(msg)
                try:
                    mirror_queue.put_nowait(new)
                except Queue.Full:
                    # drop the copy, keep the original
                    pass

            if key in self.listeners:
                try:
                    self.listeners[key].put_nowait(msg)
                except Queue.Full:
                    # FIXME: log this
                    pass

    def command(self, cmd, attrs=None, expect=None, addr=None):
        '''
        Send a control command to a broker and wait for the reply.

        * cmd -- the command code
        * attrs -- list of [name, value] attributes of the command
        * expect -- an attribute name, or a list/tuple of names,
          to extract from the reply
        * addr -- address of the broker; the default broker is
          used when not given

        Returns None if `expect` is None, the attribute value for
        a single name, or a list of values for a list/tuple of
        names.

        Raises RuntimeError with the IPR_ATTR_ERROR attribute of
        the reply if the reply is not IPRCMD_ACK.
        '''
        addr = addr or self.default_broker
        msg = mgmtmsg(io.BytesIO())
        msg['cmd'] = cmd
        # a fresh list per call: the message keeps a reference to
        # it, so a shared default list must not be used here
        msg['attrs'] = [] if attrs is None else attrs
        msg['header']['type'] = NLMSG_CONTROL
        msg.encode()
        rsp = self.request(msg.buf.getvalue(),
                           env_flags=NLT_CONTROL,
                           nonce_pool=self.cmd_nonce,
                           addr=addr)[0]
        if rsp['cmd'] != IPRCMD_ACK:
            raise RuntimeError(rsp.get_attr('IPR_ATTR_ERROR'))
        if expect is None:
            return None
        if isinstance(expect, (list, tuple)):
            return [rsp.get_attr(item) for item in expect]
        return rsp.get_attr(expect)

    def unregister(self, addr=None):
        '''
        Send IPRCMD_UNREGISTER to the broker at `addr` and
        return the result of the command.
        '''
        result = self.command(IPRCMD_UNREGISTER, addr=addr)
        return result

    def register(self, secret, addr=None):
        '''
        Send IPRCMD_REGISTER with the given secret to the broker
        at `addr` and return the result of the command.
        '''
        attrs = [['IPR_ATTR_SECRET', secret]]
        return self.command(IPRCMD_REGISTER, attrs, addr=addr)

    def discover(self, url, addr=None):
        '''
        Send IPRCMD_DISCOVER for `url` to the broker at `addr`
        and return the IPR_ATTR_ADDR attribute of the reply.
        '''
        # .. _ioc-discover:
        attrs = [['IPR_ATTR_HOST', url]]
        return self.command(IPRCMD_DISCOVER,
                            attrs,
                            expect='IPR_ATTR_ADDR',
                            addr=addr)

    def provide(self, url):
        '''
        Send IPRCMD_PROVIDE and then IPRCMD_CONNECT for `url` to
        the default broker; the result of the second command is
        returned.
        '''
        for code in (IPRCMD_PROVIDE, IPRCMD_CONNECT):
            # a new attribute list for every command
            result = self.command(code, [['IPR_ATTR_HOST', url]])
        return result

    def remove(self, url):
        '''
        Send IPRCMD_REMOVE for `url` to the default broker and
        return the result of the command.
        '''
        attrs = [['IPR_ATTR_HOST', url]]
        return self.command(IPRCMD_REMOVE, attrs)

    def serve(self, url, key='', cert='', ca='', addr=None):
        '''
        Send IPRCMD_SERVE for `url` to the broker at `addr`,
        along with the SSL key, certificate and CA attributes,
        and return the result of the command.
        '''
        attrs = [['IPR_ATTR_HOST', url],
                 ['IPR_ATTR_SSL_KEY', key],
                 ['IPR_ATTR_SSL_CERT', cert],
                 ['IPR_ATTR_SSL_CA', ca]]
        return self.command(IPRCMD_SERVE, attrs, addr=addr)

    def shutdown(self, url, addr=None):
        '''
        Send IPRCMD_SHUTDOWN for `url` to the broker at `addr`
        and return the result of the command.
        '''
        attrs = [['IPR_ATTR_HOST', url]]
        return self.command(IPRCMD_SHUTDOWN, attrs, addr=addr)

    def connect(self, host=None, key='', cert='', ca='', addr=None):
        '''
        Send IPRCMD_CONNECT for a host to the broker at `addr`.

        `host` defaults to `self.host`. The (uid, addr) pair of
        the new connection is remembered in `self.uids`, so that
        release() can disconnect it later.

        Returns the (uid, peer) tuple taken from the
        IPR_ATTR_UUID and IPR_ATTR_ADDR attributes of the reply.
        '''
        target = host or self.host
        attrs = [['IPR_ATTR_HOST', target],
                 ['IPR_ATTR_SSL_KEY', key],
                 ['IPR_ATTR_SSL_CERT', cert],
                 ['IPR_ATTR_SSL_CA', ca]]
        reply = self.command(IPRCMD_CONNECT,
                             attrs,
                             expect=['IPR_ATTR_UUID', 'IPR_ATTR_ADDR'],
                             addr=addr)
        (uid, peer) = reply
        self.uids.add((uid, addr))
        return uid, peer

    def disconnect(self, uid, addr=None):
        '''
        Send IPRCMD_DISCONNECT for the connection `uid` to the
        broker at `addr`, then forget the (uid, addr) pair.
        Returns the result of the command.
        '''
        attrs = [['IPR_ATTR_UUID', uid]]
        result = self.command(IPRCMD_DISCONNECT, attrs, addr=addr)
        self.uids.remove((uid, addr))
        return result

    def release(self):
        '''
        Tear the object down: disconnect all the remembered
        connections, stop the broker, stop the I/O loop and
        close the control socket pair.
        '''
        # disconnect() modifies self.uids, so iterate a snapshot
        for (uid, addr) in list(self.uids):
            try:
                self.disconnect(uid, addr=addr)
            except Queue.Empty as e:
                # a timeout matters only for the default broker
                if addr != self.default_broker:
                    continue
                raise e
        self.command(IPRCMD_STOP)
        broker = self.forked_broker
        if broker:
            broker.join()
        loop = self.ioloop
        loop.shutdown()
        loop.join()

        # NOTE(review): the meaning of the value 4 is not visible
        # from this code -- confirm against the broker side
        self._brs.send(struct.pack('I', 4))
        self._brs.close()
        self.bridge.close()

    def mirror(self, operate=True):
        '''
        Switch the message mirroring: while it is on, received
        messages addressed to other queues are also copied into
        the default queue 0.

        * operate -- the new state of the mirroring
        '''
        self._mirror = operate

    def monitor(self, operate=True):
        '''
        Create/destroy the default 0 queue. Netlink socket
        receives messages all the time, and there are many
        messages that are not replies. They are just
        generated by the kernel as a reflection of settings
        changes. To start receiving these messages, call
        Netlink.monitor(). They can be fetched by
        Netlink.get(0) or just Netlink.get().

        * operate -- True to subscribe and create the queue,
          False to unsubscribe and remove it

        The call is idempotent: turning the monitoring on while
        it is already on, or off while it is already off, does
        nothing. Previously both cases fell into the unsubscribe
        branch.
        '''
        if operate:
            if self.cid is not None:
                # already subscribed, keep the queue as is
                return
            # the queue must exist before the subscription is
            # requested, otherwise the first messages have
            # nowhere to go
            self.listeners[0] = Queue.Queue(maxsize=_QUEUE_MAXSIZE)
            self.cid = self.command(
                IPRCMD_SUBSCRIBE,
                [['IPR_ATTR_KEY', {
                    'offset': 8,
                    'key': 0,
                    'mask': 0
                }]],
                expect='IPR_ATTR_CID')
        elif self.cid is not None:
            self.command(IPRCMD_UNSUBSCRIBE, [['IPR_ATTR_CID', self.cid]])
            self.cid = None
            self.listeners.pop(0, None)

    def register_callback(self,
                          callback,
                          predicate=lambda e, x: True,
                          args=None):
        '''
        Register a callback to run on a message arrival.

        * callback -- a function called as
          `callback(envelope, msg, *args)`
        * predicate -- an optional callable, called as
          `predicate(envelope, msg)`; the callback runs only
          if it returns a true value. By default every message
          matches.
        * args -- a list or tuple of extra arguments for the
          callback

        If a callback returns anything but None, the message is
        not put into any queue after that.

        Simplest example, assume ipr is the IPRoute() instance::

            # a callback that just prints messages
            def cb(env, msg):
                print(msg)

            # run it for any message:
            ipr.register_callback(cb)

        More complex example, with filtering::

            # set an attribute of an object from the message
            def cb(env, msg, obj):
                obj.some_attr = msg["some key"]

            # run the callback only for the loopback device, index 1:
            ipr.register_callback(cb,
                                  lambda e, x: x.get('index', None) == 1,
                                  (self, ))

        Please note: you do **not** need to register the default 0 queue
        to invoke callbacks on broadcast messages. Callbacks are
        iterated **before** messages get enqueued.
        '''
        extra = [] if args is None else args
        record = (predicate, callback, extra)
        self.callbacks.append(record)

    def unregister_callback(self, callback):
        '''
        Drop the first registered record that refers to the
        given function. Does nothing if there is no such record.
        '''
        for position, record in enumerate(self.callbacks):
            # record is (predicate, callback, args)
            if record[1] == callback:
                del self.callbacks[position]
                return

    @debug
    def get(self,
            key=0,
            raw=False,
            timeout=None,
            terminate=None,
            nonce_pool=None):
        '''
        Get messages from a queue

        * key -- message queue number
        * raw -- not used by this method
        * timeout -- seconds to wait for every message; defaults
          to the timeout given to the constructor. For the queue
          0 it is ignored and the call waits until a message
          arrives.
        * terminate -- an optional callable; the first message
          it returns a true value for ends the collection and
          is not included into the result
        * nonce_pool -- the pool to return `key` to; defaults
          to `self.nonce`

        For a non-zero key the messages are collected until the
        reply is complete (NLMSG_DONE, or a message without
        NLM_F_MULTI), then the queue is removed, the key is
        freed and the leftovers are moved to the queue 0, if
        there is one.

        Returns the list of collected messages (without a
        marshal -- the list of their `data` fields).

        Raises Queue.Empty on a timeout for a non-zero key, or
        the error found in a message header.
        '''
        nonce_pool = nonce_pool or self.nonce
        queue = self.listeners[key]
        result = []
        e = None
        timeout = (timeout or self._timeout) if (key != 0) else 0xffff
        while True:
            # timeout should also be set to catch ctrl-c
            # Bug-Url: http://bugs.python.org/issue1360
            try:
                msg = queue.get(block=True, timeout=timeout)
            except Queue.Empty as x:
                if key == 0:
                    # The queue 0 just keeps waiting. The timeout
                    # must not be remembered here, otherwise it
                    # would be raised instead of returning the
                    # message that arrives later.
                    continue
                e = x
                break

            if (terminate is not None) and terminate(msg):
                break

            # exceptions
            if msg['header'].get('error', None) is not None:
                e = msg['header']['error']

            # RPC
            if self.marshal is None:
                data = msg.get('data', msg)
            else:
                data = msg

            # Netlink
            if (msg['header']['type'] != NLMSG_DONE):
                result.append(data)

            # break the loop if any
            if (key == 0) or (e is not None):
                break

            # wait for NLMSG_DONE if NLM_F_MULTI
            if (terminate is None) and (
                (msg['header']['type'] == NLMSG_DONE) or
                (not msg['header']['flags'] & NLM_F_MULTI)):
                break

        if key != 0:
            # delete the queue
            del self.listeners[key]
            nonce_pool.free(key)
            # get remaining messages from the queue and
            # re-route them to queue 0 or drop
            while not queue.empty():
                msg = queue.get()
                if 0 in self.listeners:
                    self.listeners[0].put(msg)

        if e is not None:
            raise e

        return result

    @debug
    def push(self, host, msg, env_flags=None, nonce=0, cname=None):
        '''
        Wrap a payload into an envelope and send it to the
        broker via the bridge socket.

        * host -- the (address, port) pair of the destination
        * msg -- the payload, sent as the IPR_ATTR_CDATA attribute
        * env_flags -- envelope header flags, set only when not None
        * nonce -- the sequence number of the envelope
        * cname -- optional value of the IPR_ATTR_CNAME attribute
        '''
        dst, dport = host
        envelope = envmsg()
        header = envelope['header']
        header['sequence_number'] = nonce
        header['pid'] = os.getpid()
        header['type'] = NLMSG_TRANSPORT
        if env_flags is not None:
            header['flags'] = env_flags
        envelope['dst'] = dst
        envelope['src'] = self.default_broker
        envelope['dport'] = dport
        envelope['ttl'] = 16
        envelope['id'] = uuid.uuid4().bytes
        envelope['attrs'] = [['IPR_ATTR_CDATA', msg]]
        if cname is not None:
            envelope['attrs'].append(['IPR_ATTR_CNAME', cname])
        envelope.encode()
        packet = envelope.buf.getvalue()
        self.bridge.send(packet)

    def request(self,
                msg,
                env_flags=0,
                addr=None,
                port=None,
                nonce=None,
                nonce_pool=None,
                cname=None,
                response_timeout=None,
                terminate=None):
        '''
        Send a payload and wait for the response.

        A listener queue is created for the nonce (allocated
        from the pool unless given), the payload is pushed to
        the destination and the response is collected with get().

        * msg -- the payload
        * env_flags -- envelope header flags
        * addr, port -- the destination; default to the default
          broker and the default destination port
        * nonce -- the sequence number to use
        * nonce_pool -- the pool to allocate the nonce from and
          to free it to; defaults to `self.nonce`
        * cname -- passed to push() as is
        * response_timeout, terminate -- passed to get() as
          `timeout` and `terminate`

        Returns whatever get() returns.
        '''
        pool = nonce_pool or self.nonce
        seq = nonce or pool.alloc()
        target = (addr or self.default_broker, port or self.default_dport)
        # the queue must be in place before the request is sent
        self.listeners[seq] = Queue.Queue(maxsize=_QUEUE_MAXSIZE)
        self.push(target, msg, env_flags, seq, cname)
        return self.get(seq,
                        nonce_pool=pool,
                        timeout=response_timeout,
                        terminate=terminate)