Ejemplo n.º 1
0
    def initdb(self, nl=None):
        '''
        Restart IPRoute channel, and create all the DB
        from scratch. Can be used when sync is lost.

        Parameters:
            - nl -- IPRoute-like object to use; when omitted
              (or falsy), a new `IPRoute()` is created
        '''
        self.nl = nl or IPRoute()
        self.nl.monitor = True
        # `async` is a reserved word since Python 3.7, so spelling the
        # call as `bind(async=...)` is a syntax error there. Passing the
        # keyword via dict unpacking sends exactly the same argument and
        # parses on every Python version.
        # NOTE(review): newer pyroute2 releases reportedly name this
        # parameter `async_cache` -- confirm against the installed one.
        self.nl.bind(**{'async': self._nl_async})

        # resolvers
        self.interfaces = Dotkeys()
        self.routes = RoutingTableSet(ipdb=self)
        # two views over the same storage: records keyed by
        # interface name and records keyed by interface index
        self.by_name = View(src=self.interfaces,
                            constraint=lambda k, v: isinstance(k, basestring))
        self.by_index = View(src=self.interfaces,
                             constraint=lambda k, v: isinstance(k, int))

        # caches
        self.ipaddr = {}
        self.neighbours = {}

        # load information
        links = self.nl.get_links()
        # first pass: register all the interfaces; master/slave
        # relations are resolved in the second pass, when every
        # interface is already in the DB
        for link in links:
            self.device_put(link, skip_slaves=True)
        for link in links:
            self.update_slaves(link)
        self.update_addr(self.nl.get_addr())
        self.update_neighbours(self.nl.get_neighbours())
        routes4 = self.nl.get_routes(family=AF_INET)
        routes6 = self.nl.get_routes(family=AF_INET6)
        self.update_routes(routes4)
        self.update_routes(routes6)
Ejemplo n.º 2
0
    def initdb(self, nl=None):
        '''
        Create the netlink sockets and build the whole DB from
        scratch.

        Parameters:
            - nl -- IPRoute-like object to use; when omitted
              (or falsy), a new `IPRoute()` is created

        `self.nl` is used for the requests issued here, while its
        clone, `self.mnl`, is the socket that gets bound.

        On any error both sockets are closed and the original
        exception is re-raised.
        '''
        self.nl = nl or IPRoute()
        self.mnl = self.nl.clone()

        # resolvers
        self.interfaces = TransactionalBase()
        self.routes = RoutingTableSet(ipdb=self,
                                      ignore_rtables=self._ignore_rtables)
        # two views over the same storage: records keyed by
        # interface name and records keyed by interface index
        self.by_name = View(src=self.interfaces,
                            constraint=lambda k, v: isinstance(k, basestring))
        self.by_index = View(src=self.interfaces,
                             constraint=lambda k, v: isinstance(k, int))

        # caches
        self.ipaddr = {}
        self.neighbours = {}

        try:
            # `async` is a reserved word since Python 3.7, so spelling
            # the call as `bind(async=...)` is a syntax error there.
            # Dict unpacking passes exactly the same keyword argument.
            # NOTE(review): newer pyroute2 releases reportedly name
            # this parameter `async_cache` -- confirm.
            self.mnl.bind(**{'async': self._nl_async})
            # load information
            links = self.nl.get_links()
            # first pass registers the interfaces, the second one
            # resolves master/slave relations between them
            for link in links:
                self._interface_add(link, skip_slaves=True)
            for link in links:
                self.update_slaves(link)
            # bridge info
            links = self.nl.get_vlans()
            for link in links:
                self._interface_add(link)
            #
            for msg in self.nl.get_addr():
                self._addr_add(msg)
            for msg in self.nl.get_neighbours():
                self._neigh_add(msg)
            for msg in self.nl.get_routes(family=AF_INET):
                self._route_add(msg)
            for msg in self.nl.get_routes(family=AF_INET6):
                self._route_add(msg)
            for msg in self.nl.get_routes(family=AF_MPLS):
                self._route_add(msg)
        except Exception as e:
            logging.error('initdb error: %s', e)
            logging.error(traceback.format_exc())
            # close the sockets independently: a failure to close
            # the first one must not leave the second one open
            for sock in (self.nl, self.mnl):
                try:
                    sock.close()
                except Exception:
                    pass
            # bare `raise` keeps the original traceback intact
            raise
Ejemplo n.º 3
0
    def initdb(self, nl=None):
        '''
        Restart IPRoute channel, and create all the DB
        from scratch. Can be used when sync is lost.

        Parameters:
            - nl -- IPRoute-like object to use; when omitted
              (or falsy), a new `IPRoute()` is created

        On any error the netlink socket is closed and the
        original exception is re-raised.
        '''
        self.nl = nl or IPRoute()

        # resolvers
        self.interfaces = Dotkeys()
        self.routes = RoutingTableSet(ipdb=self,
                                      ignore_rtables=self._ignore_rtables)
        # two views over the same storage: records keyed by
        # interface name and records keyed by interface index
        self.by_name = View(src=self.interfaces,
                            constraint=lambda k, v: isinstance(k, basestring))
        self.by_index = View(src=self.interfaces,
                             constraint=lambda k, v: isinstance(k, int))

        # caches
        self.ipaddr = {}
        self.neighbours = {}

        try:
            # `async` is a reserved word since Python 3.7, so spelling
            # the call as `bind(async=...)` is a syntax error there.
            # Dict unpacking passes exactly the same keyword argument.
            # NOTE(review): newer pyroute2 releases reportedly name
            # this parameter `async_cache` -- confirm.
            self.nl.bind(**{'async': self._nl_async})
            # load information
            links = self.nl.get_links()
            # first pass registers the interfaces, the second one
            # resolves master/slave relations between them
            for link in links:
                self.device_put(link, skip_slaves=True)
            for link in links:
                self.update_slaves(link)
            # bridge info
            links = self.nl.get_vlans()
            for link in links:
                self.update_dev(link)
            #
            self.update_addr(self.nl.get_addr())
            self.update_neighbours(self.nl.get_neighbours())
            routes4 = self.nl.get_routes(family=AF_INET)
            routes6 = self.nl.get_routes(family=AF_INET6)
            self.update_routes(routes4)
            self.update_routes(routes6)
        except Exception as e:
            logging.error('initdb error: %s', e)
            logging.error(traceback.format_exc())
            # best-effort cleanup; the original error matters more
            # than a failure to close the socket
            try:
                self.nl.close()
            except Exception:
                pass
            # bare `raise` keeps the original traceback intact
            raise
Ejemplo n.º 4
0
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 debug=False, ignore_rtables=None):
        '''
        Parameters:
            - nl -- IPRoute() reference
            - mode -- (implicit, explicit, direct)
            - restart_on_error -- rebuild the DB with `initdb()` when
              reading from netlink fails; if not specified, it is
              enabled only when no `nl` was provided
            - nl_async -- value passed to `nl.bind()`; `None` means
              "use `config.ipdb_nl_async`"
            - debug -- store raw netlink messages in the address
              and neighbour caches instead of short summaries
            - ignore_rtables -- a routing table id, or a list, tuple
              or set of ids, passed on to `RoutingTableSet`

        If you do not provide iproute instance, ipdb will
        start it automatically.
        '''
        self.mode = mode
        self.debug = debug
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self.iclass = Interface
        # honour an explicit value: previously anything but None,
        # including False, was turned into True
        if nl_async is None:
            self._nl_async = config.ipdb_nl_async
        else:
            self._nl_async = bool(nl_async)
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}
        self._cb_threads = {}

        # locks and events
        self._links_event = threading.Event()
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # load information
        if restart_on_error is None:
            restart_on_error = nl is None
        self.restart_on_error = restart_on_error
        self.initdb(nl)

        # start monitoring thread
        self._mthread = threading.Thread(target=self.serve_forever)
        # `sys.ps1` exists only in an interactive interpreter; there
        # the monitoring thread must not block the interpreter exit
        if hasattr(sys, 'ps1') and self.nl.__class__.__name__ != 'Client':
            # the `daemon` attribute replaces the deprecated setDaemon()
            self._mthread.daemon = True
        self._mthread.start()
        #
        atexit.register(self.release)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def initdb(self, nl=None):
        '''
        Restart IPRoute channel, and create all the DB
        from scratch. Can be used when sync is lost.

        Parameters:
            - nl -- IPRoute-like object to use; when omitted
              (or falsy), a new `IPRoute()` is created

        On any error the netlink socket is closed and the
        original exception is re-raised.
        '''
        self.nl = nl or IPRoute()

        # resolvers
        self.interfaces = Dotkeys()
        self.routes = RoutingTableSet(ipdb=self,
                                      ignore_rtables=self._ignore_rtables)
        # two views over the same storage: records keyed by
        # interface name and records keyed by interface index
        self.by_name = View(src=self.interfaces,
                            constraint=lambda k, v: isinstance(k, basestring))
        self.by_index = View(src=self.interfaces,
                             constraint=lambda k, v: isinstance(k, int))

        # caches
        self.ipaddr = {}
        self.neighbours = {}

        try:
            # `async` is a reserved word since Python 3.7, so spelling
            # the call as `bind(async=...)` is a syntax error there.
            # Dict unpacking passes exactly the same keyword argument.
            # NOTE(review): newer pyroute2 releases reportedly name
            # this parameter `async_cache` -- confirm.
            self.nl.bind(**{'async': self._nl_async})
            # load information
            links = self.nl.get_links()
            # first pass registers the interfaces, the second one
            # resolves master/slave relations between them
            for link in links:
                self.device_put(link, skip_slaves=True)
            for link in links:
                self.update_slaves(link)
            self.update_addr(self.nl.get_addr())
            self.update_neighbours(self.nl.get_neighbours())
            routes4 = self.nl.get_routes(family=AF_INET)
            routes6 = self.nl.get_routes(family=AF_INET6)
            self.update_routes(routes4)
            self.update_routes(routes6)
        except Exception:
            # best-effort cleanup; the original error matters more
            # than a failure to close the socket
            try:
                self.nl.close()
            except Exception:
                pass
            # bare `raise` keeps the original traceback intact
            raise

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post")

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]

        Returns the uuid of the registered callback; use it
        with `unregister_callback()`.

        Raises KeyError if `mode` is neither "post" nor "pre",
        the same way as `unregister_callback()` does.
        '''
        # Validate the mode first. An unknown mode used to be
        # silently ignored: the caller got a uuid back, but the
        # callback was never registered and never called.
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')

        lock = threading.Lock()

        # the wrapper serializes invocations of the same callback
        def safe(*argv, **kwarg):
            with lock:
                callback(*argv, **kwarg)

        # keep references for `unregister_callback()`
        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        cbchain[safe.uuid] = safe
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback registered with `register_callback()`.

        Arguments:
            - cuid -- the uuid returned by `register_callback()`
            - mode -- "post" (default) or "pre", the chain the
              callback was registered in

        Threads recorded for this callback are joined with a
        3 seconds timeout each. Returns whatever is still stored
        for `cuid` in the callback threads registry, or an empty
        tuple.

        Raises KeyError for an unknown mode or an unknown `cuid`.
        '''
        if mode not in ('post', 'pre'):
            raise KeyError('Unknown callback mode')
        if mode == 'post':
            registry = self._post_callbacks
        else:
            registry = self._pre_callbacks
        wrapper = registry[cuid]
        # the wrapper holds this lock while the callback runs, so
        # the removal waits for a running invocation to finish
        with wrapper.lock:
            registry.pop(cuid)
        pending = tuple(self._cb_threads.get(cuid, ()))
        for worker in pending:
            worker.join(3)
        return self._cb_threads.get(cuid, ())

    def release(self):
        '''
        Shutdown IPDB instance and sync the state. Since
        IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exit the
        script, it is required to properly shutdown IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there
        is enough time to sync the state. But for the scripts
        the `release()` call is required.

        The call is idempotent: only the first one does the
        work, all subsequent calls return immediately.
        '''
        with self._shutdown_lock:
            if self._stop:
                return

            self._stop = True
            # collect all the callbacks
            #
            # the monitoring thread is still running here and can
            # drop finished entries from the registry, thus `get()`
            for cuid in tuple(self._cb_threads):
                for t in tuple(self._cb_threads.get(cuid, ())):
                    t.join()
            # terminate the main loop
            #
            # the request makes a message arrive at the socket, so
            # the blocked `nl.get()` in `serve_forever()` returns
            # and the loop can see the `_stop` flag
            try:
                self.nl.put({'index': 1}, RTM_GETLINK)
                self._mthread.join()
            except Exception:
                # Just give up.
                # We can not handle this case
                pass
            self.nl.close()
            self.nl = None

            # flush all the objects
            # -- interfaces
            #
            # `detach()` removes records from the storage behind
            # `by_name`, so iterate over a snapshot
            for (key, dev) in tuple(self.by_name.items()):
                self.detach(key, dev['index'], dev.nlmsg)
            # -- routes
            for key in tuple(self.routes.tables.keys()):
                del self.routes.tables[key]
            self.routes.tables[254] = None
            # -- ipaddr
            for key in tuple(self.ipaddr.keys()):
                del self.ipaddr[key]
            # -- neighbours
            for key in tuple(self.neighbours.keys()):
                del self.neighbours[key]

    def create(self, kind, ifname, reuse=False, **kwarg):
        '''
        Create an interface. Arguments 'kind' and 'ifname' are
        required.

            - kind — interface type, can be of:
                - bridge
                - bond
                - vlan
                - tun
                - dummy
                - veth
                - macvlan
                - macvtap
                - gre
                - team
                - ovs-bridge
            - ifname — interface name
            - reuse — if such interface exists, return it anyway

        Returns the interface object with a transaction started
        on it. In the 'direct' mode the transaction is committed
        before the object is returned.

        Raises `CreateException`, if the interface already exists,
        is not in the 'shadow' scope and `reuse` is not set.

        Different interface kinds can require different
        arguments for creation.

        ► **veth**

        To properly create `veth` interface, one should specify
        `peer` also, since `veth` interfaces are created in pairs::

            with ip.create(ifname='v1p0', kind='veth', peer='v1p1') as i:
                i.add_ip('10.0.0.1/24')
                i.add_ip('10.0.0.2/24')

        The code above creates two interfaces, `v1p0` and `v1p1`, and
        adds two addresses to `v1p0`.

        ► **macvlan**

        Macvlan interfaces act like VLANs within OS. The macvlan driver
        provides an ability to add several MAC addresses on one interface,
        where every MAC address is reflected with a virtual interface in
        the system.

        In some setups macvlan interfaces can replace bridge interfaces,
        providing more simple and at the same time high-performance
        solution::

            ip.create(ifname='mvlan0',
                      kind='macvlan',
                      link=ip.interfaces.em1,
                      macvlan_mode='private').commit()

        Several macvlan modes are available: 'private', 'vepa', 'bridge',
        'passthru'. Usually the default is 'vepa'.

        ► **macvtap**

        Almost the same as macvlan, but creates also a character tap device::

            ip.create(ifname='mvtap0',
                      kind='macvtap',
                      link=ip.interfaces.em1,
                      macvtap_mode='vepa').commit()

        Will create a device file `"/dev/tap%s" % ip.interfaces.mvtap0.index`

        ► **gre**

        Create GRE tunnel::

            with ip.create(ifname='grex',
                           kind='gre',
                           gre_local='172.16.0.1',
                           gre_remote='172.16.0.101',
                           gre_ttl=16) as i:
                i.add_ip('192.168.0.1/24')
                i.up()


        ► **vlan**

        VLAN interfaces require additional parameters, `vlan_id` and
        `link`, where `link` is a master interface to create VLAN on::

            ip.create(ifname='v100',
                      kind='vlan',
                      link=ip.interfaces.eth0,
                      vlan_id=100)

            ip.create(ifname='v100',
                      kind='vlan',
                      link=1,
                      vlan_id=100)

        The `link` parameter should be either integer, interface id, or
        an interface object. VLAN id must be integer.

        ► **vxlan**

        VXLAN interfaces are like VLAN ones, but require a bit more
        parameters::

            ip.create(ifname='vx101',
                      kind='vxlan',
                      vxlan_link=ip.interfaces.eth0,
                      vxlan_id=101,
                      vxlan_group='239.1.1.1',
                      vxlan_ttl=16)

        All possible vxlan parameters are listed in the module
        `pyroute2.netlink.rtnl.ifinfmsg:... vxlan_data`.

        ► **tuntap**

        Possible `tuntap` keywords:

            - `mode` — "tun" or "tap"
            - `uid` — integer
            - `gid` — integer
            - `ifr` — dict of tuntap flags (see tuntapmsg.py)
        '''
        with self.exclusive:
            if ifname not in self.interfaces:
                # no such record yet: build it as a snapshot, fill
                # it in and only then switch to the IPDB mode
                device = self.iclass(ipdb=self, mode='snapshot')
                self.interfaces[ifname] = device
                device.update(kwarg)
                # interface objects given as links are stored
                # by their indices
                for key in ('link', 'vxlan_link'):
                    if isinstance(kwarg.get(key, None), Interface):
                        device[key] = kwarg[key]['index']
                device['kind'] = kind
                device['index'] = kwarg.get('index', 0)
                device['ifname'] = ifname
                device['ipdb_scope'] = 'create'
                device._mode = self.mode
            else:
                # the record exists: it may be reused only when it is
                # a shadow one, or when it is explicitly requested
                device = self.interfaces[ifname]
                if not (device['ipdb_scope'] == 'shadow' or reuse):
                    raise CreateException("interface %s exists" %
                                          ifname)
                kwarg['kind'] = kind
                device.load_dict(kwarg)
                device.set_item('ipdb_scope', 'create')
            tid = device.begin()
        #
        # All the device methods are handled via `transactional.update()`
        # except of the very creation.
        #
        # Commit the changes in the 'direct' mode, since this call is not
        # decorated.
        if self.mode == 'direct':
            device.commit(tid)
        return device

    def commit(self, transactions=None, rollback=False):
        '''
        Commit several transactions at once.

        Arguments:
            - transactions -- list of `(target, transaction)` pairs;
              if not specified, the last transactions of all the
              interfaces and routes that have any are collected
              and ordered by `ipdb_priority`, the highest first
            - rollback -- passed as is to every `target.commit()`;
              used by this method itself to revert the changes

        If any commit fails, the transactions list is saved as
        `self.fallen`, the targets already processed are reverted
        to the snapshots taken before their commits, and the
        exception is re-raised.
        '''
        if transactions is None:
            # interface transactions
            pending = [(iface, iface.last())
                       for iface in self.by_name.values()
                       if iface._tids]
            # route transactions
            for table in self.routes.tables.keys():
                pending.extend([(route, route.last())
                                for route in self.routes.tables[table]
                                if route._tids])
            pending.sort(key=lambda item: item[1]['ipdb_priority'],
                         reverse=True)
            transactions = pending

        snapshots = []
        removed = []

        try:
            for (target, tx) in transactions:
                # nothing to do with objects that are gone already
                if target['ipdb_scope'] == 'detached':
                    continue
                # a removal is committed as a move to the shadow
                # scope, the object is detached only on success
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                # the state to revert to if something goes wrong
                if not rollback:
                    snapshots.append((target, target.pick(detached=True)))
                target.commit(transaction=tx, rollback=rollback)
        except Exception:
            if not rollback:
                self.fallen = transactions
                self.commit(transactions=snapshots, rollback=True)
            raise
        else:
            if not rollback:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            # the transactions are dropped whatever the outcome is
            if not rollback:
                for (target, tx) in transactions:
                    target.drop(tx)

    def device_del(self, msg):
        '''
        Process a link message for an interface to be removed
        from the DB.

        Messages for unknown indices are ignored. The message
        is stored on the record as `nlmsg`. Then:

            - frozen records (truthy `_freeze`) go to the
              'shadow' scope
            - records in the 'locked' or 'shadow' scope are
              only synced
            - all other records are detached
        '''
        index = msg['index']
        target = self.interfaces.get(index)
        if target is None:
            return
        target.nlmsg = msg
        if getattr(target, '_freeze', None):
            # check for freezed devices
            self.interfaces[index].set_item('ipdb_scope', 'shadow')
        elif target.get('ipdb_scope') in ('locked', 'shadow'):
            # check for locked devices
            self.interfaces[index].sync()
        else:
            self.detach(None, index, msg)

    def device_put(self, msg, skip_slaves=False):
        '''
        Create or update an interface record from a link message.

        The record is registered in `self.interfaces` under two
        keys, the interface index and the interface name. There
        are four scenarios:

            1. neither key is known: a new interface
            2. the name is known, the index is not: the index
               changed
            3. the index is known, the name is not: the name
               changed
            4. both are known: a simple update with an optional
               name change

        Arguments:
            - msg -- the link message
            - skip_slaves -- do not call `update_slaves()` for
              the message
        '''
        index = msg.get('index', None)
        ifname = msg.get_attr('IFLA_IFNAME', None)

        if index in self.interfaces:
            # scenarios #3 and #4: rename, if the name differs
            device = self.interfaces[index]
            previous = device['ifname']
            if previous != ifname:
                # unlink old name
                del self.interfaces[previous]
            self.interfaces[ifname] = device
        elif ifname in self.interfaces:
            # scenario #2: the record moves to the new index, and
            # so do the address and neighbour caches
            device = self.interfaces[ifname]
            stale = device['index']
            self.interfaces[index] = device
            if stale in self.interfaces:
                del self.interfaces[stale]
            for cache in (self.ipaddr, self.neighbours):
                if stale in cache:
                    cache[index] = cache[stale]
                    del cache[stale]
        else:
            # scenario #1: new interface
            device = self.iclass(ipdb=self)
            self.interfaces[index] = device
            self.interfaces[ifname] = device

        # for interfaces, created by IPDB
        if index not in self.ipaddr:
            self.ipaddr[index] = IPaddrSet()
        if index not in self.neighbours:
            self.neighbours[index] = LinkedSet()

        device.load_netlink(msg)

        if not skip_slaves:
            self.update_slaves(msg)

    def detach(self, name, idx, msg=None):
        '''
        Remove an interface record from the DB.

        Arguments:
            - name -- interface name, used for the lookup when
              `idx` is None or less than 1
            - idx -- interface index
            - msg -- optional link message that caused the removal

        When a message is given, the slaves lists are updated
        first. An 'RTM_DELLINK' message with `change` other than
        0xffffffff does not remove the record.

        The record is synced, removed from the interfaces storage
        and from the address and neighbour caches, and then put
        into the 'detached' scope.
        '''
        with self.exclusive:
            if msg is not None:
                try:
                    self.update_slaves(msg)
                except KeyError:
                    pass
                if msg['event'] == 'RTM_DELLINK':
                    if msg['change'] != 0xffffffff:
                        return
            # resolve the missing half of the (name, index) pair
            lookup_by_name = idx is None or idx < 1
            target = self.interfaces[name if lookup_by_name else idx]
            if lookup_by_name:
                idx = target['index']
            else:
                name = target['ifname']
            target.sync()
            for key in (name, idx):
                self.interfaces.pop(key, None)
            for cache in (self.ipaddr, self.neighbours):
                cache.pop(idx, None)
            target.set_item('ipdb_scope', 'detached')

    def watchdog(self, action='RTM_NEWLINK', **kwarg):
        '''
        Return a `Watchdog` object created for this IPDB instance
        with the action name (default 'RTM_NEWLINK') and the rest
        of the keyword arguments packed in a dict.
        '''
        return Watchdog(self, action, kwarg)

    def update_dev(self, dev):
        '''
        Dispatch a link message: 'RTM_NEWLINK' messages go to
        `device_put()`, all the rest -- to `device_del()`.

        Messages with `change` other than 0xffffffff are dropped,
        if the interface index is not registered in the DB.
        '''
        # ignore non-system updates on devices not
        # registered in the DB
        unknown = dev['index'] not in self.interfaces
        if unknown and dev['change'] != 0xffffffff:
            return
        if dev['event'] == 'RTM_NEWLINK':
            handler = self.device_put
        else:
            handler = self.device_del
        handler(dev)

    def update_routes(self, routes):
        '''
        Load route messages, one by one and in the given order,
        into `self.routes`.
        '''
        for message in routes:
            table_set = self.routes
            table_set.load_netlink(message)

    def _lookup_master(self, msg):
        master = None
        # lookup for IFLA_INFO_OVS_MASTER
        li = msg.get_attr('IFLA_LINKINFO')
        if li:
            master = li.get_attr('IFLA_INFO_OVS_MASTER')
        # lookup for IFLA_MASTER
        if master is None:
            master = msg.get_attr('IFLA_MASTER')
        # pls keep in mind, that in the case of IFLA_MASTER
        # lookup is done via interface index, while in the case
        # of IFLA_INFO_OVS_MASTER lookup is done via ifname
        return self.interfaces.get(master, None)

    def update_slaves(self, msg):
        '''
        Keep the `ports` lists of master interfaces in sync with
        a link message.

        If the message refers to a master (see `_lookup_master()`):

            - on 'RTM_NEWLINK' the interface is removed from the
              ports of all the interfaces and added to the ports
              of the master
            - on 'RTM_DELLINK' the interface is removed from the
              ports of the master

        If there is no master, the interface is removed from all
        the ports lists and its own `master` item is deleted.

        Raises KeyError, if there is no master and the interface
        index from the message is not in the DB.
        '''
        # Update slaves list -- only after update IPDB!

        master = self._lookup_master(msg)
        index = msg['index']
        # there IS a master for the interface
        if master is not None:
            if msg['event'] == 'RTM_NEWLINK':
                # TODO tags: ipdb
                # The code serves one particular case, when
                # an enslaved interface is set to belong to
                # another master. In this case there will be
                # no 'RTM_DELLINK', only 'RTM_NEWLINK', and
                # we can end up in a broken state, when two
                # masters refers to the same slave
                for device in self.by_index:
                    if index in self.interfaces[device]['ports']:
                        try:
                            self.interfaces[device].del_port(
                                index, direct=True)
                        except KeyError:
                            pass
                # only now register the port with the actual master
                master.add_port(index, direct=True)
            elif msg['event'] == 'RTM_DELLINK':
                if index in master['ports']:
                    master.del_port(index, direct=True)
        # there is NO masters for the interface, clean them if any
        else:
            # an unknown index raises KeyError here
            device = self.interfaces[msg['index']]

            # clean device from ports
            for master in self.by_index:
                if index in self.interfaces[master]['ports']:
                    try:
                        self.interfaces[master].del_port(
                            index, direct=True)
                    except KeyError:
                        pass
            # the master as recorded on the device itself
            master = device.if_master
            if master is not None:
                if 'master' in device:
                    device.del_item('master')
                if (master in self.interfaces) and \
                        (msg['index'] in self.interfaces[master].ports):
                    try:
                        self.interfaces[master].del_port(
                            msg['index'], direct=True)
                    except KeyError:
                        pass

    def update_addr(self, addrs, action='add'):
        '''
        Update address lists of interfaces.

        Arguments:
            - addrs -- iterable of address messages
            - action -- name of the method to call on the address
              set of the interface, 'add' (default) or 'remove'

        Messages, for which `get_addr_nla()` returns None, are
        skipped. The update is best-effort: an error for a single
        message (e.g. an unknown interface index) is logged at
        the debug level and does not stop the processing.
        '''
        for addr in addrs:
            nla = get_addr_nla(addr)
            if nla is None:
                # no address to use as a key, so there is no
                # need to build the record either
                continue
            if self.debug:
                raw = addr
            else:
                raw = {'local': addr.get_attr('IFA_LOCAL'),
                       'broadcast': addr.get_attr('IFA_BROADCAST'),
                       'address': addr.get_attr('IFA_ADDRESS'),
                       'flags': addr.get_attr('IFA_FLAGS'),
                       'prefixlen': addr.get('prefixlen')}
            try:
                method = getattr(self.ipaddr[addr['index']], action)
                method(key=(nla, addr['prefixlen']), raw=raw)
            except Exception:
                # `Exception`, not a bare `except`: KeyboardInterrupt
                # and SystemExit must not be swallowed here
                logging.debug('address update (%s) failed', action,
                              exc_info=True)

    def update_neighbours(self, neighs, action='add'):
        '''
        Update neighbour lists of interfaces.

        Arguments:
            - neighs -- iterable of neighbour messages
            - action -- name of the method to call on the neighbours
              set of the interface, 'add' (default) or 'remove'

        Messages without `NDA_DST` are skipped. The update is
        best-effort: an error for a single message (e.g. an
        unknown interface index) is logged at the debug level
        and does not stop the processing.
        '''
        for neigh in neighs:
            nla = neigh.get_attr('NDA_DST')
            if nla is None:
                # no address to use as a key, so there is no
                # need to build the record either
                continue
            if self.debug:
                raw = neigh
            else:
                raw = {'lladdr': neigh.get_attr('NDA_LLADDR')}
            try:
                method = getattr(self.neighbours[neigh['ifindex']], action)
                method(key=nla, raw=raw)
            except Exception:
                # `Exception`, not a bare `except`: KeyboardInterrupt
                # and SystemExit must not be swallowed here
                logging.debug('neighbour update (%s) failed', action,
                              exc_info=True)

    def serve_forever(self):
        '''
        Main monitoring cycle. It gets messages from the
        default iproute queue and updates objects in the
        database.

        For every message the "pre" callbacks are run
        synchronously, then the DB is updated under the
        `exclusive` lock, then the "post" callbacks are started
        in separate threads.

        If reading from netlink fails, the DB is rebuilt with
        `initdb()` when `restart_on_error` is set; the method
        returns if that rebuild fails. Without `restart_on_error`
        `RuntimeError` is raised.

        .. note::
            Should not be called manually.
        '''
        while not self._stop:
            try:
                messages = self.nl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except Exception:
                # `Exception`, not a bare `except`: SystemExit and
                # KeyboardInterrupt must not trigger a restart
                logging.error('Restarting IPDB instance after '
                              'error:\n%s', traceback.format_exc())
                if self.restart_on_error:
                    try:
                        self.initdb()
                    except Exception:
                        logging.error('Error restarting DB:\n%s',
                                      traceback.format_exc())
                        return
                    continue
                else:
                    raise RuntimeError('Emergency shutdown')
            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except Exception:
                        # a broken callback must not stop the loop,
                        # but leave a trace for debugging
                        logging.debug('pre-callback %s failed:\n%s',
                                      cuid, traceback.format_exc())

                with self.exclusive:
                    # read the event only now, pre-callbacks
                    # are allowed to patch the message
                    event = msg.get('event', None)
                    # FIXME: refactor it to a dict
                    if event in ('RTM_NEWLINK', 'RTM_DELLINK'):
                        self.update_dev(msg)
                        self._links_event.set()
                    elif event == 'RTM_NEWADDR':
                        self.update_addr([msg], 'add')
                    elif event == 'RTM_DELADDR':
                        self.update_addr([msg], 'remove')
                    elif event == 'RTM_NEWNEIGH':
                        self.update_neighbours([msg], 'add')
                    elif event == 'RTM_DELNEIGH':
                        self.update_neighbours([msg], 'remove')
                    elif event in ('RTM_NEWROUTE', 'RTM_DELROUTE'):
                        self.update_routes([msg])

                # run post-callbacks
                # NOTE: post-callbacks are asynchronous
                for (cuid, cb) in tuple(self._post_callbacks.items()):
                    t = threading.Thread(name="callback %s" % (id(cb)),
                                         target=cb,
                                         args=(self, msg, msg['event']))
                    t.start()
                    if cuid not in self._cb_threads:
                        self._cb_threads[cuid] = set()
                    self._cb_threads[cuid].add(t)

                # occasionally join cb threads
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads.get(cuid, ())):
                        t.join(0)
                        if not t.is_alive():
                            try:
                                self._cb_threads[cuid].remove(t)
                            except KeyError:
                                pass
                            if len(self._cb_threads.get(cuid, ())) == 0:
                                # `pop()` tolerates an entry that
                                # is gone already
                                self._cb_threads.pop(cuid, None)
Ejemplo n.º 5
0
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 debug=False, ignore_rtables=None):
        '''
        Load the database and start the monitoring thread.

        Arguments:

            - **nl** -- netlink object to work with; when omitted,
                `initdb()` creates an `IPRoute` instance itself
            - **mode** -- stored as `self.mode` and assigned to the
                interfaces made by `create()`
            - **restart_on_error** -- whether `serve_forever()` may
                re-run `initdb()` after a failure; by default it is
                enabled only when no `nl` was supplied
            - **nl_async** -- value passed to `bind()` of the
                monitoring channel; `None` means "use
                `config.ipdb_nl_async`"
            - **debug** -- stored as `self.debug`
            - **ignore_rtables** -- a routing table id or a
                list/tuple/set of ids, handed over to
                `RoutingTableSet`; any other value means an
                empty list
        '''
        self.mode = mode
        self.debug = debug
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self.iclass = Interface
        # honour an explicit value: previously any non-None argument,
        # including False, was turned into True
        self._nl_async = config.ipdb_nl_async if nl_async is None \
            else nl_async
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}
        self._cb_threads = {}

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None
        self.initdb(nl)

        # start monitoring thread
        self._mthread = threading.Thread(target=self.serve_forever)
        # the attribute form replaces the deprecated setDaemon()
        self._mthread.daemon = True
        self._mthread.start()
        # make sure the instance is released on interpreter exit
        atexit.register(self.release)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def initdb(self, nl=None):
        '''
        Open the netlink channels and create all the DB from
        scratch.

        `nl` is the netlink object to use for requests; when it is
        omitted, a new `IPRoute` is created. A clone of it
        (`self.mnl`) is bound for monitoring.

        On any error during the load both channels are closed
        and the error is re-raised.
        '''
        self.nl = nl or IPRoute()
        self.mnl = self.nl.clone()

        # resolvers
        self.interfaces = TransactionalBase()
        self.routes = RoutingTableSet(ipdb=self,
                                      ignore_rtables=self._ignore_rtables)
        self.by_name = View(src=self.interfaces,
                            constraint=lambda k, v: isinstance(k, basestring))
        self.by_index = View(src=self.interfaces,
                             constraint=lambda k, v: isinstance(k, int))

        # caches
        self.ipaddr = {}
        self.neighbours = {}

        try:
            # `async` is a reserved word since Python 3.7, so the
            # keyword can not be spelled literally any more; passing
            # it via a dict keeps the very same call and still
            # parses on old and new interpreters
            self.mnl.bind(**{'async': self._nl_async})
            # load information
            links = self.nl.get_links()
            for link in links:
                self._interface_add(link, skip_slaves=True)
            # slaves can be resolved only when all the links are loaded
            for link in links:
                self.update_slaves(link)
            # bridge info
            for link in self.nl.get_vlans():
                self._interface_add(link)
            #
            for msg in self.nl.get_addr():
                self._addr_add(msg)
            for msg in self.nl.get_neighbours():
                self._neigh_add(msg)
            for family in (AF_INET, AF_INET6, AF_MPLS):
                for msg in self.nl.get_routes(family=family):
                    self._route_add(msg)
        except Exception as e:
            logging.error('initdb error: %s', e)
            logging.error(traceback.format_exc())
            # close the channels one by one: a failure to close the
            # first one must not leave the second one open
            for channel in (self.nl, self.mnl):
                try:
                    channel.close()
                except Exception:
                    pass
            raise e

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post")

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        safe = cbchain[cuid]
        with safe.lock:
            cbchain.pop(cuid)
        for t in tuple(self._cb_threads.get(cuid, ())):
            t.join(3)
        ret = self._cb_threads.get(cuid, ())
        return ret

    def release(self):
        '''
        Shutdown IPDB instance and sync the state. Since
        IPDB is asyncronous, some operations continue in the
        background, e.g. callbacks. So, prior to exit the
        script, it is required to properly shutdown IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there
        is enough time to sync the state. But for the scripts
        the `release()` call is required.
        '''
        with self.exclusive:
            with self._shutdown_lock:
                if self._stop:
                    return

                self._stop = True
                # collect all the callbacks
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads[cuid]):
                        t.join()
                # terminate the main loop
                try:
                    for t in range(3):
                        self.mnl.put({'index': 1}, RTM_GETLINK)
                        self._mthread.join(t)
                        if not self._mthread.is_alive():
                            break
                except Exception:
                    # Just give up.
                    # We can not handle this case
                    pass
                self.nl.close()
                self.nl = None
                self.mnl.close()
                self.mnl = None

                # flush all the objects
                for (key, dev) in self.by_name.items():
                    try:
                        self.detach(key, dev['index'], dev.nlmsg)
                    except KeyError:
                        pass

                def flush(idx):
                    for key in tuple(idx.keys()):
                        try:
                            del idx[key]
                        except KeyError:
                            pass
                idx_list = [self.routes.tables[x] for x
                            in self.routes.tables.keys()]
                idx_list.append(self.ipaddr)
                idx_list.append(self.neighbours)
                for idx in idx_list:
                    flush(idx)

    def create(self, kind, ifname, reuse=False, **kwarg):
        with self.exclusive:
            # check for existing interface
            if ifname in self.interfaces:
                if (self.interfaces[ifname]['ipdb_scope'] == 'shadow') \
                        or reuse:
                    device = self.interfaces[ifname]
                    kwarg['kind'] = kind
                    device.load_dict(kwarg)
                    if self.interfaces[ifname]['ipdb_scope'] == 'shadow':
                        with device._direct_state:
                            device['ipdb_scope'] = 'create'
                else:
                    raise CreateException("interface %s exists" %
                                          ifname)
            else:
                device = \
                    self.interfaces[ifname] = \
                    self.iclass(ipdb=self, mode='snapshot')
                device.update(kwarg)
                if isinstance(kwarg.get('link', None), Interface):
                    device['link'] = kwarg['link']['index']
                if isinstance(kwarg.get('vxlan_link', None), Interface):
                    device['vxlan_link'] = kwarg['vxlan_link']['index']
                device['kind'] = kind
                device['index'] = kwarg.get('index', 0)
                device['ifname'] = ifname
                device['ipdb_scope'] = 'create'
                device._mode = self.mode
            device.begin()
        return device

    def commit(self, transactions=None, rollback=False):
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            # collect interface transactions
            txlist = [(x, x.current_tx) for x
                      in self.by_name.values() if x.local_tx.values()]
            # collect route transactions
            for table in self.routes.tables.keys():
                txlist.extend([(x, x.current_tx) for x in
                               self.routes.tables[table]
                               if x.local_tx.values()])
            txlist = sorted(txlist,
                            key=lambda x: x[1]['ipdb_priority'],
                            reverse=True)
            transactions = txlist

        snapshots = []
        removed = []

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if not rollback:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                target.commit(transaction=tx, rollback=rollback)
        except Exception:
            if not rollback:
                self.fallen = transactions
                self.commit(transactions=snapshots, rollback=True)
            raise
        else:
            if not rollback:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if not rollback:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

    def _interface_del(self, msg):
        target = self.interfaces.get(msg['index'])
        if target is None:
            return
        for record in self.routes.filter({'oif': msg['index']}):
            with record['route']._direct_state:
                record['route']['ipdb_scope'] = 'gc'
        for record in self.routes.filter({'iif': msg['index']}):
            with record['route']._direct_state:
                record['route']['ipdb_scope'] = 'gc'
        target.nlmsg = msg
        # check for freezed devices
        if getattr(target, '_freeze', None):
            with target._direct_state:
                target['ipdb_scope'] = 'shadow'
            return
        # check for locked devices
        if target.get('ipdb_scope') in ('locked', 'shadow'):
            return
        self.detach(None, msg['index'], msg)

    def _interface_add(self, msg, skip_slaves=False):
        # check, if a record exists
        index = msg.get('index', None)
        ifname = msg.get_attr('IFLA_IFNAME', None)
        # scenario #1: no matches for both: new interface
        # scenario #2: ifname exists, index doesn't: index changed
        # scenario #3: index exists, ifname doesn't: name changed
        # scenario #4: both exist: assume simple update and
        # an optional name change
        if ((index not in self.interfaces) and
                (ifname not in self.interfaces)):
            # scenario #1, new interface
            device = \
                self.interfaces[index] = \
                self.interfaces[ifname] = self.iclass(ipdb=self)
        elif ((index not in self.interfaces) and
                (ifname in self.interfaces)):
            # scenario #2, index change
            old_index = self.interfaces[ifname]['index']
            device = self.interfaces[index] = self.interfaces[ifname]
            if old_index in self.interfaces:
                del self.interfaces[old_index]
            if old_index in self.ipaddr:
                self.ipaddr[index] = self.ipaddr[old_index]
                del self.ipaddr[old_index]
            if old_index in self.neighbours:
                self.neighbours[index] = self.neighbours[old_index]
                del self.neighbours[old_index]
        else:
            # scenario #3, interface rename
            # scenario #4, assume rename
            old_name = self.interfaces[index]['ifname']
            if old_name != ifname:
                # unlink old name
                del self.interfaces[old_name]
            device = self.interfaces[ifname] = self.interfaces[index]

        if index not in self.ipaddr:
            # for interfaces, created by IPDB
            self.ipaddr[index] = IPaddrSet()

        if index not in self.neighbours:
            self.neighbours[index] = LinkedSet()

        device.load_netlink(msg)

        if not skip_slaves:
            self.update_slaves(msg)

    def detach(self, name, idx, msg=None):
        with self.exclusive:
            if msg is not None:
                try:
                    self.update_slaves(msg)
                except KeyError:
                    pass
                if msg['event'] == 'RTM_DELLINK' and \
                        msg['change'] != 0xffffffff:
                    return
            if idx is None or idx < 1:
                target = self.interfaces[name]
                idx = target['index']
            else:
                target = self.interfaces[idx]
                name = target['ifname']
            self.interfaces.pop(name, None)
            self.interfaces.pop(idx, None)
            self.ipaddr.pop(idx, None)
            self.neighbours.pop(idx, None)
            with target._direct_state:
                target['ipdb_scope'] = 'detached'

    def watchdog(self, action='RTM_NEWLINK', **kwarg):
        '''
        Return a `Watchdog` object made for this IPDB instance,
        the `action` (an event name, 'RTM_NEWLINK' by default)
        and the keyword arguments, passed as one dict.
        '''
        return Watchdog(self, action, kwarg)

    def _route_add(self, msg):
        self.routes.load_netlink(msg)

    def _route_del(self, msg):
        self.routes.load_netlink(msg)

    def update_slaves(self, msg):
        # Update slaves list -- only after update IPDB!

        index = msg['index']
        master_index = msg.get_attr('IFLA_MASTER')
        if index == master_index:
            # one special case: links() call with AF_BRIDGE
            # returns IFLA_MASTER == index
            return
        master = self.interfaces.get(master_index, None)
        # there IS a master for the interface
        if master is not None:
            if msg['event'] == 'RTM_NEWLINK':
                # TODO tags: ipdb
                # The code serves one particular case, when
                # an enslaved interface is set to belong to
                # another master. In this case there will be
                # no 'RTM_DELLINK', only 'RTM_NEWLINK', and
                # we can end up in a broken state, when two
                # masters refers to the same slave
                for device in self.by_index:
                    if index in self.interfaces[device]['ports']:
                        try:
                            with self.interfaces[device]._direct_state:
                                self.interfaces[device].del_port(index)
                        except KeyError:
                            pass
                with master._direct_state:
                    master.add_port(index)
            elif msg['event'] == 'RTM_DELLINK':
                if index in master['ports']:
                    with master._direct_state:
                        master.del_port(index)
        # there is NO masters for the interface, clean them if any
        else:
            device = self.interfaces[msg['index']]
            # clean vlan list from the port
            for vlan in tuple(device['vlans']):
                with device._direct_state:
                    device.del_vlan(vlan)
            # clean device from ports
            for master in self.by_index:
                if index in self.interfaces[master]['ports']:
                    try:
                        with self.interfaces[master]._direct_state:
                            self.interfaces[master].del_port(index)
                    except KeyError:
                        pass
            master = device.if_master
            if master is not None:
                if 'master' in device:
                    with device._direct_state:
                        device['master'] = None
                if (master in self.interfaces) and \
                        (msg['index'] in self.interfaces[master]['ports']):
                    try:
                        with self.interfaces[master]._direct_state:
                            self.interfaces[master].del_port(index)
                    except KeyError:
                        pass

    def _addr_add(self, msg):
        if msg['family'] == AF_INET:
            addr = msg.get_attr('IFA_LOCAL')
        elif msg['family'] == AF_INET6:
            addr = msg.get_attr('IFA_ADDRESS')
        else:
            return
        raw = {'local': msg.get_attr('IFA_LOCAL'),
               'broadcast': msg.get_attr('IFA_BROADCAST'),
               'address': msg.get_attr('IFA_ADDRESS'),
               'flags': msg.get_attr('IFA_FLAGS'),
               'prefixlen': msg['prefixlen']}
        try:
            self.ipaddr[msg['index']].add(key=(addr, raw['prefixlen']),
                                          raw=raw)
        except:
            pass

    def _addr_del(self, msg):
        if msg['family'] == AF_INET:
            addr = msg.get_attr('IFA_LOCAL')
        elif msg['family'] == AF_INET6:
            addr = msg.get_attr('IFA_ADDRESS')
        else:
            return
        try:
            self.ipaddr[msg['index']].remove((addr, msg['prefixlen']))
        except:
            pass

    def _neigh_add(self, msg):
        if msg['family'] == AF_BRIDGE:
            return

        try:
            (self
             .neighbours[msg['ifindex']]
             .add(key=msg.get_attr('NDA_DST'),
                  raw={'lladdr': msg.get_attr('MDA_LLADDR')}))
        except:
            pass

    def _neigh_del(self, msg):
        if msg['family'] == AF_BRIDGE:
            return
        try:
            (self
             .neighbours[msg['ifindex']]
             .remove(msg.get_attr('NDA_DST')))
        except:
            pass

    def serve_forever(self):
        ###
        # Main monitoring cycle. It gets messages from the
        # default iproute queue and updates objects in the
        # database.
        #
        # Should not be called manually.
        ###
        event_map = {'RTM_NEWLINK': self._interface_add,
                     'RTM_DELLINK': self._interface_del,
                     'RTM_NEWADDR': self._addr_add,
                     'RTM_DELADDR': self._addr_del,
                     'RTM_NEWNEIGH': self._neigh_add,
                     'RTM_DELNEIGH': self._neigh_del,
                     'RTM_NEWROUTE': self._route_add,
                     'RTM_DELROUTE': self._route_del}
        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except:
                logging.error('Restarting IPDB instance after '
                              'error:\n%s', traceback.format_exc())
                if self.restart_on_error:
                    try:
                        self.initdb()
                    except:
                        logging.error('Error restarting DB:\n%s',
                                      traceback.format_exc())
                        return
                    continue
                else:
                    raise RuntimeError('Emergency shutdown')
            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except:
                        pass

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in event_map:
                        event_map[event](msg)

                # run post-callbacks
                # NOTE: post-callbacks are asynchronous
                for (cuid, cb) in tuple(self._post_callbacks.items()):
                    t = threading.Thread(name="callback %s" % (id(cb)),
                                         target=cb,
                                         args=(self, msg, msg['event']))
                    t.start()
                    if cuid not in self._cb_threads:
                        self._cb_threads[cuid] = set()
                    self._cb_threads[cuid].add(t)

                # occasionally join cb threads
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads.get(cuid, ())):
                        t.join(0)
                        if not t.is_alive():
                            try:
                                self._cb_threads[cuid].remove(t)
                            except KeyError:
                                pass
                            if len(self._cb_threads.get(cuid, ())) == 0:
                                del self._cb_threads[cuid]