Esempio n. 1
0
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.

    `initdb()` starts two daemon threads: the main event loop
    (`_serve_main`), which reads RT netlink messages from the
    monitoring socket and dispatches them to the loaded plugins,
    and the callback loop (`_serve_cb`), which runs the "post"
    callbacks. Call `release()` -- or use the instance as a
    context manager -- to stop the main loop and close the
    sockets.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 sndbuf=1048576, rcvbuf=1048576,
                 nl_bind_groups=RTMGRP_DEFAULTS,
                 ignore_rtables=None, callbacks=None,
                 sort_addresses=False, plugins=None):
        '''
        Set up the database and start the service threads.

        :param nl: RT netlink socket to use for commands; when None,
            IPDB creates and owns an `IPRoute` socket
        :param mode: stored as `self.mode`; not used by this class
            directly
        :param restart_on_error: re-run `initdb()` when reading from the
            monitoring socket fails; defaults to True only if `nl` is None
        :param nl_async: passed as `async_cache` to the monitoring
            socket `bind()`; None means use `config.ipdb_nl_async`
        :param sndbuf: send buffer size of an owned `IPRoute` socket
        :param rcvbuf: receive buffer size of an owned `IPRoute` socket
        :param nl_bind_groups: netlink groups to bind the monitoring
            socket to; a plugin is loaded only if all its groups are
            included
        :param ignore_rtables: an int, or a list/tuple/set, stored as
            `self._ignore_rtables` (an int is wrapped in a list, any
            other value becomes an empty list)
        :param callbacks: iterable of callables or `(callable, mode)`
            sequences, each passed to `register_callback()`
        :param sort_addresses: use `SortedIPaddrSet` instead of
            `IPaddrSet` for address sets
        :param plugins: names of plugins to enable; defaults to
            `['interfaces', 'routes', 'rules']`, unknown names are ignored
        '''
        plugins = plugins or ['interfaces', 'routes', 'rules']
        pmap = {'interfaces': interfaces,
                'routes': routes,
                'rules': rules}
        self.mode = mode
        self.txdrop = False
        self._stdout = sys.stdout
        self._ipaddr_set = SortedIPaddrSet if sort_addresses else IPaddrSet
        self._event_map = {}
        self._deferred = {}
        self._ensure = []
        self._loaded = set()
        self._mthread = None
        self._nl_own = nl is None
        # honour an explicit nl_async=False instead of forcing True
        self._nl_async = (config.ipdb_nl_async if nl_async is None
                          else bool(nl_async))
        self.mnl = None
        self.nl = nl
        self._sndbuf = sndbuf
        self._rcvbuf = rcvbuf
        self.nl_bind_groups = nl_bind_groups
        self._plugins = [pmap[x] for x in plugins if x in pmap]
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}

        # local event queues
        # - callbacks event queue; it is bounded, overflows are
        #   counted in _cbq_drop and reported by the main loop
        self._cbq = queue.Queue(maxsize=8192)
        self._cbq_drop = 0
        # - users event queue (see eventqueue())
        self._evq = None
        self._evq_lock = threading.Lock()
        self._evq_drop = 0

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # register callbacks
        #
        # examples::
        #   def cb1(ipdb, msg, event):
        #       print(event, msg)
        #   def cb2(...):
        #       ...
        #
        #   # default mode: post
        #   IPDB(callbacks=[cb1, cb2])
        #   # specify the mode explicitly
        #   IPDB(callbacks=[(cb1, 'pre'), (cb2, 'post')])
        #
        for cba in callbacks or []:
            if not isinstance(cba, (tuple, list, set)):
                cba = (cba, )
            self.register_callback(*cba)

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None

        # init the database
        self.initdb()

        # init the dir() cache: public class attributes plus the
        # names of the lazily loaded plugin objects
        self.__dir_cache__ = [i for i in self.__class__.__dict__.keys()
                              if i[0] != '_']
        self.__dir_cache__.extend(list(self._deferred.keys()))

        # release the instance at interpreter exit; only a weak
        # reference is held, so atexit does not keep it alive
        def cleanup(ref):
            ipdb_obj = ref()
            if ipdb_obj is not None:
                ipdb_obj.release()
        atexit.register(cleanup, weakref.ref(self))

    def __dir__(self):
        '''Return the names cached in `__init__`.'''
        return self.__dir_cache__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def _flush_db(self):
        '''
        Detach all interfaces and empty the address, neighbour,
        route table and rule indexes of the loaded plugins.
        '''

        def flush(idx):
            for key in tuple(idx.keys()):
                try:
                    del idx[key]
                except KeyError:
                    pass
        idx_list = []
        if 'interfaces' in self._loaded:
            # iterate a snapshot: _detach() may remove entries
            for (key, dev) in tuple(self.by_name.items()):
                try:
                    # FIXME
                    self.interfaces._detach(key,
                                            dev['index'],
                                            dev.nlmsg)
                except KeyError:
                    pass
            idx_list.append(self.ipaddr)
            idx_list.append(self.neighbours)
        if 'routes' in self._loaded:
            idx_list.extend([self.routes.tables[x] for x
                             in self.routes.tables.keys()])
        if 'rules' in self._loaded:
            idx_list.append(self.rules)
        for idx in idx_list:
            flush(idx)

    def initdb(self):
        '''
        (Re)initialize the database.

        Drops the event map, flushes the loaded objects, (re)creates
        the owned command socket and the monitoring socket, re-arms
        the deferred plugin registry and starts the service threads
        that are not running yet. The main loop also calls it to
        restart after a monitoring socket error.

        :raises: whatever the monitoring socket `bind()` raises, after
            the sockets involved have been closed
        '''
        # flush all the DB objects
        with self.exclusive:

            # explicitly cleanup object references
            for event in tuple(self._event_map):
                del self._event_map[event]

            self._flush_db()

            # if the command socket is not provided, create it
            if self._nl_own:
                if self.nl is not None:
                    self.nl.close()
                self.nl = IPRoute(sndbuf=self._sndbuf, rcvbuf=self._rcvbuf)
            # setup monitoring socket
            if self.mnl is not None:
                self._flush_mnl()
                self.mnl.close()
            self.mnl = self.nl.clone()
            try:
                self.mnl.bind(groups=self.nl_bind_groups,
                              async_cache=self._nl_async)
            except BaseException:
                # drop references to closed sockets, so that release()
                # and the next initdb() do not touch them again
                self.mnl.close()
                self.mnl = None
                # close the command socket only if IPDB created it
                if self._nl_own:
                    self.nl.close()
                    self.nl = None
                raise

            # explicitly cleanup references
            for key in tuple(self._deferred):
                del self._deferred[key]

            # register the plugin objects for lazy creation, see
            # __getattribute__; plugins whose groups are not all bound
            # are skipped
            for module in self._plugins:
                if (module.groups & self.nl_bind_groups) != module.groups:
                    continue
                for plugin in module.spec:
                    self._deferred[plugin['name']] = module.spec
                    if plugin['name'] in self._loaded:
                        delattr(self, plugin['name'])
                        self._loaded.remove(plugin['name'])

            # start the service threads that are not alive; on a restart
            # from the main loop that thread is still running and reused
            for tspec in (('_mthread', '_serve_main', 'IPDB main event loop'),
                          ('_cthread', '_serve_cb', 'IPDB cb event loop')):
                tg = getattr(self, tspec[0], None)
                if not getattr(tg, 'is_alive', lambda: False)():
                    tx = threading.Thread(name=tspec[2],
                                          target=getattr(self, tspec[1]))
                    setattr(self, tspec[0], tx)
                    tx.daemon = True
                    tx.start()

    def __getattribute__(self, name):
        # Lazily create plugin objects: on the first access to a name
        # registered in _deferred, instantiate every object of its spec,
        # then run their _register() hooks and merge their event maps.
        # _deferred is read through object.__getattribute__ to avoid
        # re-entering this override.
        deferred = super(IPDB, self).__getattribute__('_deferred')
        if name in deferred:
            register = []
            spec = deferred[name]
            for plugin in spec:
                obj = plugin['class'](self, **plugin['kwarg'])
                setattr(self, plugin['name'], obj)
                register.append(obj)
                self._loaded.add(plugin['name'])
                del deferred[plugin['name']]
            for obj in register:
                if hasattr(obj, '_register'):
                    obj._register()
                if hasattr(obj, '_event_map'):
                    for event in obj._event_map:
                        if event not in self._event_map:
                            self._event_map[event] = []
                        self._event_map[event].append(obj._event_map[event])
        return super(IPDB, self).__getattribute__(name)

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously, one after
        another, in a dedicated callback thread. Messages reach
        that thread through an internal queue bounded to 8192
        entries; while it is full, messages are not delivered to
        "post" callbacks and the number of dropped events is
        logged later. A slow "post" callback delays the following
        ones, but it does not block the main event loop.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        Exceptions raised by callbacks of either type are logged
        and otherwise ignored.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post"); any other value
              than "post" or "pre" raises `KeyError`

        It returns an id to pass to `unregister_callback()`. A
        callback never runs concurrently with itself, and
        `unregister_callback()` waits for a running invocation.

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        else:
            raise KeyError('Unknown callback mode')
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback registered with `register_callback()`,
        waiting for a running invocation of it to finish.

        :param cuid: id returned by `register_callback()`
        :param mode: 'post' (default) or 'pre'
        :returns: the removed callback wrapper
        :raises KeyError: unknown mode or unknown `cuid`
        '''
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        safe = cbchain[cuid]
        with safe.lock:
            ret = cbchain.pop(cuid)
        return ret

    def eventqueue(self, qsize=8192, block=True, timeout=None):
        '''
        Initializes event queue and returns event queue context manager.
        Once the context manager is initialized, events start to be collected,
        so it is possible to read initial state from the system without losing
        last moment changes, and once that is done, start processing events.

        Example:

            ipdb = IPDB()
            with ipdb.eventqueue() as evq:
                my_state = ipdb.<needed_attribute>...
                for msg in evq:
                    update_state_by_msg(my_state, msg)
        '''
        return _evq_context(self, qsize, block, timeout)

    def eventloop(self, qsize=8192, block=True, timeout=None):
        """
        Event generator for simple cases when there is no need for initial
        state setup. Initialize event queue and yield events as they happen.
        """
        with self.eventqueue(qsize=qsize, block=block, timeout=timeout) as evq:
            for msg in evq:
                yield msg

    def release(self):
        '''
        Shutdown IPDB instance and sync the state. Since
        IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exit the
        script, it is required to properly shutdown IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there
        is enough time to sync the state. But for the scripts
        the `release()` call is required.

        Repeated calls only log a warning.
        '''
        with self._shutdown_lock:
            if self._stop:
                log.warning("shutdown in progress")
                return
            self._stop = True
            # make the callback thread leave its loop
            self._cbq.put(ShutdownException("shutdown"))

            if self._mthread is not None:
                # send requests to the monitoring socket so that the
                # main loop returns from mnl.get() and sees _stop
                self._flush_mnl()
                self._mthread.join()

            if self.mnl is not None:
                self.mnl.close()
                self.mnl = None

            # an owned command socket may already be gone if the last
            # initdb() failed
            if self._nl_own and self.nl is not None:
                self.nl.close()
                self.nl = None

            self._flush_db()

    def _flush_mnl(self):
        '''
        Send three RTM_GETLINK requests (ifindex 1) through the
        monitoring socket to terminate the main loop. Errors are
        logged and otherwise ignored.
        '''
        if self.mnl is not None:
            # terminate the main loop
            for _ in range(3):
                try:
                    msg = ifinfmsg()
                    msg['index'] = 1
                    msg.reset()
                    self.mnl.put(msg, RTM_GETLINK)
                except Exception as e:
                    # nothing more can be done here: log and go on
                    log.error("shutdown error: %s", e)

    def create(self, kind, ifname, reuse=False, **kwarg):
        '''Shortcut for `self.interfaces.add(kind, ifname, reuse, ...)`.'''
        return self.interfaces.add(kind, ifname, reuse, **kwarg)

    def ensure(self, cmd='add', reachable=None, condition=None):
        '''
        Manage the list of check routines in `self._ensure`.

        :param cmd: 'add' -- with a string `reachable` without ':',
            append `partial(test_reachable_icmp, reachable)`
            (presumably an ICMP reachability test); a string with ':'
            raises NotImplementedError; with a non-string `reachable`
            print the list when stdin is a tty. 'run' -- call every
            routine. 'reset' -- clear the list. 'print' -- print the
            list. 'get' -- return the list.
        :param reachable: target for 'add'
        :param condition: unused
        :raises NotImplementedError: unknown `cmd`
        '''
        if cmd == 'reset':
            self._ensure = []
        elif cmd == 'run':
            for f in self._ensure:
                f()
        elif cmd == 'add':
            if isinstance(reachable, basestring):
                reachable = reachable.split(':')
                if len(reachable) == 1:
                    f = partial(test_reachable_icmp, reachable[0])
                else:
                    raise NotImplementedError()
                self._ensure.append(f)
            else:
                if sys.stdin.isatty():
                    pprint(self._ensure, stream=self._stdout)
        elif cmd == 'print':
            pprint(self._ensure, stream=self._stdout)
        elif cmd == 'get':
            return self._ensure
        else:
            raise NotImplementedError()

    def items(self):
        '''
        Yield `(key, object)` pairs for all interfaces, keyed as
        `('interfaces', ifname)`, and all routes, keyed as
        `('routes', table, route_key)`.
        '''
        # TODO: add support for filters?

        # iterate interfaces
        for ifname in getattr(self, 'by_name', {}):
            yield (('interfaces', ifname), self.interfaces[ifname])

        # iterate routes
        for table in getattr(getattr(self, 'routes', None),
                             'tables', {}):
            for key, route in self.routes.tables[table].items():
                yield (('routes', table, key), route)

    def dump(self):
        '''Return a nested dict built from the keys produced by `items()`.'''
        ret = {}
        for key, obj in self.items():
            ptr = ret
            for step in key[:-1]:
                if step not in ptr:
                    ptr[step] = {}
                ptr = ptr[step]
            ptr[key[-1]] = obj
        return ret

    def load(self, config, ptr=None):
        '''
        Apply a nested dict, as produced by `dump()`: each key is
        looked up as an attribute of `ptr` (default: self); objects
        with `load()` get their subtree, other objects are recursed
        into, and missing keys are passed to `ptr.add(**value)` when
        `ptr` has `add()`. Returns self.
        '''
        if ptr is None:
            ptr = self

        for key in config:
            obj = getattr(ptr, key, None)
            if obj is not None:
                if hasattr(obj, 'load'):
                    obj.load(config[key])
                else:
                    self.load(config[key], ptr=obj)
            elif hasattr(ptr, 'add'):
                ptr.add(**config[key])

        return self

    def review(self):
        '''
        Return a nested dict of `obj.review()` results for the objects
        from `items()`; objects whose `review()` raises TypeError are
        skipped.

        :raises TypeError: when no object returned a review
        '''
        ret = {}
        for key, obj in self.items():
            ptr = ret
            try:
                rev = obj.review()
            except TypeError:
                continue

            for step in key[:-1]:
                if step not in ptr:
                    ptr[step] = {}
                ptr = ptr[step]
            ptr[key[-1]] = rev

        if not ret:
            raise TypeError('no transaction started')
        return ret

    def drop(self):
        '''
        Call `drop()` on every object from `items()`, skipping those
        that raise TypeError.

        :raises TypeError: when no object's `drop()` succeeded
        '''
        ok = False
        for key, obj in self.items():
            try:
                obj.drop()
            except TypeError:
                continue
            ok = True
        if not ok:
            raise TypeError('no transaction started')

    def commit(self, transactions=None, phase=1):
        '''
        Commit transactions in a fixed priority order.

        :param transactions: iterable of `(target, tx)` pairs; when
            None, the current transactions of interfaces and routes
            with a started local transaction are used
        :param phase: 1 -- snapshot every target, apply the changes
            and, on any failure, roll back ALL the targets by
            re-committing the snapshots with phase 2, then re-raise;
            2 -- used internally for that rollback
        :returns: self
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            # collect interface transactions
            txlist = [(x, x.current_tx) for x
                      in getattr(self, 'by_name', {}).values()
                      if x.local_tx.values()]
            # collect route transactions
            for table in getattr(getattr(self, 'routes', None),
                                 'tables', {}).keys():
                txlist.extend([(x, x.current_tx) for x in
                               self.routes.tables[table]
                               if x.local_tx.values()])
            transactions = txlist

        snapshots = []
        removed = []

        tx_ipdb_prio = []
        tx_main = []
        tx_prio1 = []
        tx_prio2 = []
        tx_prio3 = []
        for (target, tx) in transactions:
            # 8<------------------------------
            # first -- explicit priorities
            if tx['ipdb_priority']:
                tx_ipdb_prio.append((target, tx))
                continue
            # 8<------------------------------
            # routes
            if isinstance(target, BaseRoute):
                tx_prio3.append((target, tx))
                continue
            # 8<------------------------------
            # intefaces
            kind = target.get('kind', None)
            if kind in ('vlan', 'vxlan', 'gre', 'tuntap', 'vti', 'vti6',
                        'vrf'):
                tx_prio1.append((target, tx))
            elif kind in ('bridge', 'bond'):
                tx_prio2.append((target, tx))
            else:
                tx_main.append((target, tx))
            # 8<------------------------------

        # explicitly sorted transactions
        tx_ipdb_prio = sorted(tx_ipdb_prio,
                              key=lambda x: x[1]['ipdb_priority'],
                              reverse=True)

        # The final transactions order:
        # 1. any txs with ipdb_priority (highest value first)
        #
        # Then come default priorities (no ipdb_priority specified):
        # 2. all the rest
        # 3. vlan, vxlan, gre, tuntap, vti, vti6, vrf
        # 4. bridge, bond
        # 5. routes
        transactions = tx_ipdb_prio + tx_main + tx_prio1 + tx_prio2 + tx_prio3

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if phase == 1:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                # apply the changes, but NO rollback -- only phase 1
                target.commit(transaction=tx,
                              commit_phase=phase,
                              commit_mask=phase)
                # if the commit above fails, the next code
                # branch will run rollbacks
        except Exception:
            if phase == 1:
                # run rollbacks for ALL the collected transactions,
                # even successful ones; objects that were being created
                # (both snapshot and target in 'create' scope) are not
                # rolled back
                self.fallen = transactions
                txs = [x for x in snapshots
                       if not ('create' ==
                               x[0]['ipdb_scope'] ==
                               x[1]['ipdb_scope'])]
                self.commit(transactions=txs, phase=2)
            raise
        else:
            if phase == 1:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if phase == 1:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

        return self

    def watchdog(self, wdops='RTM_NEWLINK', **kwarg):
        '''Return a `Watchdog` for `wdops` events matching `kwarg`.'''
        return Watchdog(self, wdops, kwarg)

    def _serve_cb(self):
        ###
        # Callbacks thread working on a dedicated event queue.
        # It exits on ShutdownException (queued by release()) and
        # re-raises any other exception object found in the queue.
        ###

        while not self._stop:
            msg = self._cbq.get()
            self._cbq.task_done()
            if isinstance(msg, ShutdownException):
                return
            elif isinstance(msg, Exception):
                raise msg
            for cb in tuple(self._post_callbacks.values()):
                try:
                    cb(self, msg, msg['event'])
                except Exception:
                    # a failing callback must not stop the loop
                    log.error('post-callback error:\n%s',
                              traceback.format_exc())

    def _serve_main(self):
        ###
        # Main monitoring cycle. It gets messages from the
        # default iproute queue and updates objects in the
        # database.
        ###

        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except Exception as e:
                with self.exclusive:
                    if self._evq:
                        self._evq.put(e)
                        return
                if self.restart_on_error:
                    log.error('Restarting IPDB instance after '
                              'error:\n%s', traceback.format_exc())
                    try:
                        self.initdb()
                    except Exception:
                        log.error('Error restarting DB:\n%s',
                                  traceback.format_exc())
                        return
                    continue
                else:
                    log.error('Emergency shutdown, cleanup manually')
                    raise RuntimeError('Emergency shutdown')

            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for cb in tuple(self._pre_callbacks.values()):
                    try:
                        cb(self, msg, msg['event'])
                    except Exception:
                        # a failing callback must not stop the loop
                        log.error('pre-callback error:\n%s',
                                  traceback.format_exc())

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in self._event_map:
                        for func in self._event_map[event]:
                            func(msg)

                    # Post-callbacks
                    try:
                        self._cbq.put_nowait(msg)
                        if self._cbq_drop:
                            log.warning('dropped %d events',
                                        self._cbq_drop)
                            self._cbq_drop = 0
                    except queue.Full:
                        self._cbq_drop += 1
                    except Exception:
                        log.error('Emergency shutdown, cleanup manually')
                        raise RuntimeError('Emergency shutdown')

                    #
                    # Why not to put these two pieces of the code
                    # it in a routine?
                    #
                    # TODO: run performance tests with routines

                    # Users event queue
                    if self._evq:
                        try:
                            self._evq.put_nowait(msg)
                            if self._evq_drop:
                                log.warning("dropped %d events",
                                            self._evq_drop)
                                self._evq_drop = 0
                        except queue.Full:
                            self._evq_drop += 1
                        except Exception:
                            log.error('Emergency shutdown, cleanup manually')
                            raise RuntimeError('Emergency shutdown')
Esempio n. 2
0
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''
    def __init__(self,
                 nl=None,
                 mode='implicit',
                 restart_on_error=None,
                 nl_async=None,
                 nl_bind_groups=RTNL_GROUPS,
                 ignore_rtables=None,
                 callbacks=None,
                 sort_addresses=False,
                 plugins=None):
        '''
        Set up the database and start the monitoring thread.

        :param nl: RT netlink socket to use for commands; when None,
            IPDB creates and owns an `IPRoute` socket
        :param mode: stored as `self.mode`; not used here directly
        :param restart_on_error: stored as `self.restart_on_error`;
            defaults to True only if `nl` is None
        :param nl_async: async flag for the monitoring socket `bind()`;
            None means use `config.ipdb_nl_async`
        :param nl_bind_groups: netlink groups to bind the monitoring
            socket to; a plugin is loaded only if all its groups are
            included
        :param ignore_rtables: an int, or a list/tuple/set, stored as
            `self._ignore_rtables` (an int is wrapped in a list, any
            other value becomes an empty list)
        :param callbacks: iterable of callables or `(callable, mode)`
            sequences, each passed to `register_callback()`
        :param sort_addresses: stored as `self.sort_addresses`
        :param plugins: names of plugins to enable; defaults to
            `['interfaces', 'routes', 'rules']`, unknown names are ignored
        '''
        import weakref

        plugins = plugins or ['interfaces', 'routes', 'rules']
        pmap = {'interfaces': interfaces, 'routes': routes, 'rules': rules}
        self.mode = mode
        self.sort_addresses = sort_addresses
        self._event_map = {}
        self._deferred = {}
        self._loaded = set()
        self._mthread = None
        self._nl_own = nl is None
        # honour an explicit nl_async=False instead of forcing True
        self._nl_async = (config.ipdb_nl_async if nl_async is None
                          else bool(nl_async))
        self.mnl = None
        self.nl = nl
        self.nl_bind_groups = nl_bind_groups
        self._plugins = [pmap[x] for x in plugins if x in pmap]
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [
                ignore_rtables,
            ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}
        self._cb_threads = {}

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # register callbacks
        #
        # examples::
        #   def cb1(ipdb, msg, event):
        #       print(event, msg)
        #   def cb2(...):
        #       ...
        #
        #   # default mode: post
        #   IPDB(callbacks=[cb1, cb2])
        #   # specify the mode explicitly
        #   IPDB(callbacks=[(cb1, 'pre'), (cb2, 'post')])
        #
        for cba in callbacks or []:
            if not isinstance(cba, (tuple, list, set)):
                cba = (cba, )
            self.register_callback(*cba)

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None

        # init the database
        self.initdb()

        # release the instance at interpreter exit; hold only a weak
        # reference, so that atexit does not keep a released instance
        # alive for the whole lifetime of the process
        def cleanup(ref):
            ipdb_obj = ref()
            if ipdb_obj is not None:
                ipdb_obj.release()
        atexit.register(cleanup, weakref.ref(self))

    def __enter__(self):
        '''Context manager entry: hand back this very instance.'''
        instance = self
        return instance

    def __exit__(self, exc_type, exc_value, traceback):
        '''
        Context manager exit: shut the instance down with `release()`.
        Returns None, so exceptions from the `with` body propagate.
        '''
        self.release()
        return None

    def initdb(self):
        '''
        (Re)initialize the database.

        Releases the previous state without closing a caller-provided
        command socket, creates the owned command socket and the
        monitoring socket, re-arms the deferred plugin registry and
        starts the monitoring thread.

        :raises: whatever the monitoring socket `bind()` raises, after
            the sockets involved have been closed
        '''
        # common event map, empty by default, so all the
        # events are just ignored
        self.release(complete=False)
        self._stop = False
        # explicitly cleanup object references
        for event in tuple(self._event_map):
            del self._event_map[event]

        # if the command socket is not provided, create it
        if self._nl_own:
            self.nl = IPRoute()
        # setup monitoring socket
        self.mnl = self.nl.clone()
        try:
            # `async` is a reserved word since Python 3.7, so the
            # keyword argument can only be passed via dict unpacking
            self.mnl.bind(groups=self.nl_bind_groups,
                          **{'async': self._nl_async})
        except BaseException:
            self.mnl.close()
            # close the command socket only if IPDB created it
            if self._nl_own:
                self.nl.close()
            raise

        # explicitly cleanup references
        for key in tuple(self._deferred):
            del self._deferred[key]

        # register the plugin objects for lazy creation, see
        # __getattribute__; plugins whose groups are not all bound
        # are skipped
        for module in self._plugins:
            if (module.groups & self.nl_bind_groups) != module.groups:
                continue
            for plugin in module.spec:
                self._deferred[plugin['name']] = module.spec
                if plugin['name'] in self._loaded:
                    delattr(self, plugin['name'])
                    self._loaded.remove(plugin['name'])

        # start the monitoring thread
        self._mthread = threading.Thread(name="IPDB event loop",
                                         target=self.serve_forever)
        self._mthread.daemon = True
        self._mthread.start()

    def __getattribute__(self, name):
        '''
        Lazily instantiate plugin objects on first attribute access.

        When `name` is a key of `self._deferred`, every plugin in the
        spec it maps to is created as
        `plugin['class'](self, **plugin['kwarg'])`, stored as the
        attribute `plugin['name']`, added to `self._loaded` and removed
        from `self._deferred`. Only after all objects of the spec exist
        are their `_register()` hooks called and their `_event_map`
        handlers appended to `self._event_map` (presumably so a hook
        may use sibling objects of the same spec). Then the attribute
        is looked up normally.

        NOTE(review): no lock is taken here; two threads hitting the
        same deferred name at once could race (e.g. KeyError on the
        second `del`) -- TODO confirm callers serialize first access.
        '''
        # read _deferred through object.__getattribute__ so this
        # override is not re-entered for it
        deferred = super(IPDB, self).__getattribute__('_deferred')
        if name in deferred:
            register = []
            spec = deferred[name]
            # first pass: create and attach all objects of the spec
            for plugin in spec:
                obj = plugin['class'](self, **plugin['kwarg'])
                setattr(self, plugin['name'], obj)
                register.append(obj)
                self._loaded.add(plugin['name'])
                del deferred[plugin['name']]
            # second pass: run hooks and merge per-event handler lists
            for obj in register:
                if hasattr(obj, '_register'):
                    obj._register()
                if hasattr(obj, '_event_map'):
                    for event in obj._event_map:
                        if event not in self._event_map:
                            self._event_map[event] = []
                        self._event_map[event].append(obj._event_map[event])
        return super(IPDB, self).__getattribute__(name)

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post"); any other value
              than "post" or "pre" raises `KeyError`

        It returns an id to pass to `unregister_callback()`. The
        callback is wrapped with a lock, so it never runs
        concurrently with itself.

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        else:
            # do not hand out an id for a callback that is never run;
            # matches unregister_callback()
            raise KeyError('Unknown callback mode')
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback registered with `register_callback()`.

        The callback's own lock is taken around the removal, so a
        running invocation finishes first; then every thread recorded
        in `self._cb_threads` for `cuid` is joined with a 3 second
        timeout.

        :param cuid: id returned by `register_callback()`
        :param mode: 'post' (default) or 'pre'
        :returns: what `self._cb_threads` still holds for `cuid`, or an
            empty tuple
        :raises KeyError: unknown mode or unknown `cuid`
        '''
        if mode == 'pre':
            cbchain = self._pre_callbacks
        elif mode == 'post':
            cbchain = self._post_callbacks
        else:
            raise KeyError('Unknown callback mode')

        with cbchain[cuid].lock:
            cbchain.pop(cuid)

        pending = tuple(self._cb_threads.get(cuid, ()))
        for worker in pending:
            worker.join(3)

        return self._cb_threads.get(cuid, ())

    def release(self, complete=True):
        '''
        Shutdown IPDB instance and sync the state. Since
        IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exiting the
        script, it is required to properly shut down IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there
        is enough time to sync the state. But for the scripts
        the `release()` call is required.

        :param complete: when true (default) the command socket
            `self.nl` is closed together with the monitoring socket;
            when false it is closed only if IPDB owns it
            (`self._nl_own`). Either way this happens only when a
            monitoring socket `self.mnl` was open.

        A second call after the instance is stopped returns
        immediately.
        '''
        with self._shutdown_lock:
            # already stopped: nothing to do
            if self._stop:
                return
            self._stop = True
            if self.mnl is not None:
                # terminate the main loop
                # presumably the replies to these RTM_GETLINK requests wake
                # the blocking self.mnl.get() in serve_forever(), which then
                # sees _stop and exits -- confirm against the socket class
                for t in range(3):
                    try:
                        msg = ifinfmsg()
                        msg['index'] = 1
                        msg.reset()
                        self.mnl.put(msg, RTM_GETLINK)
                    except Exception as e:
                        logging.warning("shutdown error: %s", e)
                        # Just give up.
                        # We can not handle this case

            # NOTE(review): serve_forever() calls self.initdb() from the
            # event loop thread; if initdb() calls release() (as in other
            # versions of this class), this join() targets the current
            # thread and raises RuntimeError -- confirm.
            if self._mthread is not None:
                self._mthread.join()

            if self.mnl is not None:
                self.mnl.close()
                self.mnl = None
                # the command socket goes too on a complete shutdown, or
                # whenever IPDB created it itself
                if complete or self._nl_own:
                    self.nl.close()
                    self.nl = None

        # the rest runs outside _shutdown_lock, under the DB lock
        with self.exclusive:
            # collect all the callbacks
            # (blocks until every tracked post-callback thread finishes)
            for cuid in tuple(self._cb_threads):
                for t in tuple(self._cb_threads[cuid]):
                    t.join()

            # flush all the objects
            def flush(idx):
                # snapshot the keys: the index shrinks while we delete
                for key in tuple(idx.keys()):
                    try:
                        del idx[key]
                    except KeyError:
                        pass

            idx_list = []

            # only plugins that were actually loaded are touched
            if 'interfaces' in self._loaded:
                for (key, dev) in self.by_name.items():
                    try:
                        # FIXME
                        self.interfaces._detach(key, dev['index'], dev.nlmsg)
                    except KeyError:
                        pass
                idx_list.append(self.ipaddr)
                idx_list.append(self.neighbours)

            if 'routes' in self._loaded:
                idx_list.extend(
                    [self.routes.tables[x] for x in self.routes.tables.keys()])

            if 'rules' in self._loaded:
                idx_list.append(self.rules)

            for idx in idx_list:
                flush(idx)

    def create(self, kind, ifname, reuse=False, **kwarg):
        '''
        Shortcut for `self.interfaces.add()`.

        :param kind: interface kind, passed through unchanged
        :param ifname: interface name, passed through unchanged
        :param reuse: passed through positionally as the third argument
        :returns: whatever `self.interfaces.add()` returns
        '''
        ifaces = self.interfaces
        return ifaces.add(kind, ifname, reuse, **kwarg)

    def commit(self, transactions=None, phase=1):
        '''
        Commit transactions, rolling the targets back on failure.

        :param transactions: a list of `(target, tx)` pairs. When None,
            every started transaction is collected from the interface
            objects (`by_name`) and the route tables (`routes`) that are
            available, ordered by tx['ipdb_priority'], highest first.
        :param phase: 1 for a normal commit: a detached snapshot of each
            target is taken first and, on failure, restored with a
            phase 2 commit. Phase 2 is that rollback pass.
        :raises: whatever a target's `commit()` raised, after the
            rollback; the failed list is kept in `self.fallen`.
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            txlist = []
            # A plugin may not be registered at all (e.g. its netlink
            # groups were not bound); its attribute then does not exist.
            # Skip it instead of failing with AttributeError.
            by_name = getattr(self, 'by_name', None)
            if by_name is not None:
                # collect interface transactions
                txlist.extend((x, x.current_tx) for x in by_name.values()
                              if x.local_tx.values())
            routes = getattr(self, 'routes', None)
            if routes is not None:
                # collect route transactions
                for table in routes.tables.keys():
                    txlist.extend((x, x.current_tx)
                                  for x in routes.tables[table]
                                  if x.local_tx.values())
            transactions = sorted(txlist,
                                  key=lambda x: x[1]['ipdb_priority'],
                                  reverse=True)

        snapshots = []
        removed = []

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if phase == 1:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                target.commit(transaction=tx,
                              commit_phase=phase,
                              commit_mask=phase)
        except Exception:
            if phase == 1:
                # roll every touched target back to its snapshot
                self.fallen = transactions
                self.commit(transactions=snapshots, phase=2)
            raise
        else:
            if phase == 1:
                # removals become final only after a successful commit
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if phase == 1:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

    def watchdog(self, action='RTM_NEWLINK', **kwarg):
        '''
        Build a `Watchdog` bound to this IPDB for the netlink event
        `action`. Extra keyword arguments are handed over as a single
        dict (not unpacked).
        '''
        criteria = kwarg
        return Watchdog(self, action, criteria)

    def serve_forever(self):
        '''
        Main monitoring cycle. It gets messages from the monitoring
        socket `self.mnl` and updates objects in the database.

        For every message: pre-callbacks run synchronously, then the
        handlers in `self._event_map` run under `self.exclusive`, then
        each post-callback is started in its own thread.

        Runs until `self._stop` is set. If reading fails, the DB is
        re-initialized when `self.restart_on_error` is true; otherwise
        RuntimeError('Emergency shutdown') is raised.

        Should not be called manually.
        '''
        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except Exception:
                log.error('Restarting IPDB instance after '
                          'error:\n%s', traceback.format_exc())
                if self.restart_on_error:
                    try:
                        self.initdb()
                    except Exception:
                        log.error('Error restarting DB:\n%s',
                                  traceback.format_exc())
                        return
                    continue
                else:
                    raise RuntimeError('Emergency shutdown')

            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except Exception:
                        # a broken pre-callback must not stop the loop,
                        # but its failure should not vanish silently
                        log.error('Pre-callback error:\n%s',
                                  traceback.format_exc())

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in self._event_map:
                        for func in self._event_map[event]:
                            func(msg)

                # run post-callbacks
                # NOTE: post-callbacks are asynchronous
                for (cuid, cb) in tuple(self._post_callbacks.items()):
                    t = threading.Thread(name="IPDB callback %s" % (id(cb)),
                                         target=cb,
                                         args=(self, msg, msg['event']))
                    t.start()
                    if cuid not in self._cb_threads:
                        self._cb_threads[cuid] = set()
                    self._cb_threads[cuid].add(t)

                # occasionally join cb threads
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads.get(cuid, ())):
                        t.join(0)
                        if not t.is_alive():
                            try:
                                self._cb_threads[cuid].remove(t)
                            except KeyError:
                                pass
                            if len(self._cb_threads.get(cuid, ())) == 0:
                                # tolerate an already-removed entry
                                self._cb_threads.pop(cuid, None)

    def init_ipaddr_set(self):
        # Return a new address set: the sorted flavour when
        # self.sort_addresses is true, the plain one otherwise.
        container = SortedIPaddrSet if self.sort_addresses else IPaddrSet
        return container()
Esempio n. 3
0
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 nl_bind_groups=RTNL_GROUPS,
                 ignore_rtables=None, callbacks=None,
                 sort_addresses=False):
        '''
        Create the database and start monitoring netlink events.

        :param nl: a command netlink socket to reuse; when None, IPDB
            creates its own `IPRoute()` and owns it.
        :param mode: stored as `self.mode` ('implicit' by default).
        :param restart_on_error: re-initialize the DB when reading
            events fails; None means "only when IPDB owns the socket".
        :param nl_async: passed to the monitoring socket's `bind()`;
            None means use `config.ipdb_nl_async`.
        :param nl_bind_groups: netlink groups to subscribe to; plugins
            whose groups are not all included are not registered.
        :param ignore_rtables: an int or a list/tuple/set of ints,
            stored as `self._ignore_rtables` (presumably routing table
            ids to ignore).
        :param callbacks: callbacks to register at startup: plain
            callables ('post' mode) or `(callable, mode)` pairs.
        :param sort_addresses: make `init_ipaddr_set()` return
            `SortedIPaddrSet` instead of `IPaddrSet`.
        '''
        self.mode = mode
        self.sort_addresses = sort_addresses
        self._event_map = {}
        self._deferred = {}
        self._loaded = set()
        self._mthread = None
        self._nl_own = nl is None
        # only None falls back to the config default; an explicit
        # nl_async=False must not be turned into True
        if nl_async is None:
            self._nl_async = config.ipdb_nl_async
        else:
            self._nl_async = bool(nl_async)
        self.mnl = None
        self.nl = nl
        self.nl_bind_groups = nl_bind_groups
        self._plugins = [interface, route, rule]
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}
        self._cb_threads = {}

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # register callbacks
        #
        # examples::
        #   def cb1(ipdb, msg, event):
        #       print(event, msg)
        #   def cb2(...):
        #       ...
        #
        #   # default mode: post
        #   IPDB(callbacks=[cb1, cb2])
        #   # specify the mode explicitly
        #   IPDB(callbacks=[(cb1, 'pre'), (cb2, 'post')])
        #
        for cba in callbacks or []:
            if not isinstance(cba, (tuple, list, set)):
                cba = (cba, )
            self.register_callback(*cba)

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None

        # init the database
        self.initdb()

        # Register the shutdown hook through a weak reference, so that
        # atexit does not keep every IPDB instance alive until the
        # interpreter exits.
        import weakref

        def cleanup(ref):
            ipdb_obj = ref()
            if ipdb_obj is not None:
                ipdb_obj.release()
        atexit.register(cleanup, weakref.ref(self))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def initdb(self):
        '''
        (Re)initialize the database: shut the previous state down,
        open the netlink sockets, register the plugins as deferred
        attributes and start a new event loop thread.
        '''
        self.release(complete=False)
        self._stop = False
        # common event map, empty by default, so all the
        # events are just ignored;
        # explicitly cleanup object references
        for event in tuple(self._event_map):
            del self._event_map[event]

        # if the command socket is not provided, create it
        if self._nl_own:
            self.nl = IPRoute()
        # setup monitoring socket
        self.mnl = self.nl.clone()
        try:
            # `async` is a reserved word since Python 3.7 and can not be
            # written as a literal keyword argument; pass it via a dict
            self.mnl.bind(groups=self.nl_bind_groups,
                          **{'async': self._nl_async})
        except BaseException:
            self.mnl.close()
            # _nl_own is a bool (never None): close the command socket
            # only when IPDB created it itself
            if self._nl_own:
                self.nl.close()
            raise

        # explicitly cleanup references
        for key in tuple(self._deferred):
            del self._deferred[key]

        # register plugins as deferred (lazily loaded) attributes, but
        # only those whose netlink groups are all subscribed
        for module in self._plugins:
            if (module.groups & self.nl_bind_groups) != module.groups:
                continue
            for plugin in module.spec:
                self._deferred[plugin['name']] = module.spec
                if plugin['name'] in self._loaded:
                    delattr(self, plugin['name'])
                    self._loaded.remove(plugin['name'])

        # start the monitoring thread
        self._mthread = threading.Thread(name="IPDB event loop",
                                         target=self.serve_forever)
        self._mthread.daemon = True
        self._mthread.start()

    def __getattribute__(self, name):
        # Lazy plugin loading: names registered in _deferred are
        # instantiated on first access, together with the rest of
        # their spec, and their event handlers are hooked up.
        deferred = super(IPDB, self).__getattribute__('_deferred')
        if name in deferred:
            register = []
            spec = deferred[name]
            for plugin in spec:
                obj = plugin['class'](self, **plugin['kwarg'])
                setattr(self, plugin['name'], obj)
                register.append(obj)
                self._loaded.add(plugin['name'])
                del deferred[plugin['name']]
            for obj in register:
                if hasattr(obj, '_register'):
                    obj._register()
                if hasattr(obj, '_event_map'):
                    for event in obj._event_map:
                        if event not in self._event_map:
                            self._event_map[event] = []
                        self._event_map[event].append(obj._event_map[event])
        return super(IPDB, self).__getattribute__(name)

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post")

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]

        Returns the callback id to be passed to
        `unregister_callback()`. Raises KeyError on an unknown
        mode.
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            # invocations of one callback never overlap
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        else:
            # previously the callback was silently dropped
            raise KeyError('Unknown callback mode')
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback registered with `register_callback()`, then
        wait up to 3 seconds for each of its tracked threads. Returns
        what `self._cb_threads` still holds for the id (or an empty
        tuple). Raises KeyError on an unknown mode or id.
        '''
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        safe = cbchain[cuid]
        # waits for a running invocation, which holds this lock
        with safe.lock:
            cbchain.pop(cuid)
        for t in tuple(self._cb_threads.get(cuid, ())):
            t.join(3)
        ret = self._cb_threads.get(cuid, ())
        return ret

    def release(self, complete=True):
        '''
        Shutdown IPDB instance and sync the state. Since
        IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exiting the
        script, it is required to properly shut down IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there
        is enough time to sync the state. But for the scripts
        the `release()` call is required.

        :param complete: when true (default) the command socket is
            closed as well; when false only if IPDB owns it.
        '''
        with self._shutdown_lock:
            if self._stop:
                return
            self._stop = True
            if self.mnl is not None:
                # terminate the main loop
                for _ in range(3):
                    try:
                        msg = ifinfmsg()
                        msg['index'] = 1
                        msg.reset()
                        self.mnl.put(msg, RTM_GETLINK)
                    except Exception as e:
                        logging.warning("shutdown error: %s", e)
                        # Just give up.
                        # We can not handle this case

            # release() is also reached from the event loop thread itself
            # (serve_forever -> initdb -> release on restart); joining the
            # current thread would raise RuntimeError
            if self._mthread is not None and \
                    self._mthread is not threading.current_thread():
                self._mthread.join()

            if self.mnl is not None:
                self.mnl.close()
                self.mnl = None
                if complete or self._nl_own:
                    self.nl.close()
                    self.nl = None

        with self.exclusive:
            # collect all the callbacks
            for cuid in tuple(self._cb_threads):
                for t in tuple(self._cb_threads[cuid]):
                    t.join()

            # flush all the objects
            def flush(idx):
                for key in tuple(idx.keys()):
                    try:
                        del idx[key]
                    except KeyError:
                        pass

            idx_list = []

            if 'interfaces' in self._loaded:
                for (key, dev) in self.by_name.items():
                    try:
                        # FIXME
                        self.interfaces._detach(key,
                                                dev['index'],
                                                dev.nlmsg)
                    except KeyError:
                        pass
                idx_list.append(self.ipaddr)
                idx_list.append(self.neighbours)

            if 'routes' in self._loaded:
                idx_list.extend([self.routes.tables[x] for x
                                 in self.routes.tables.keys()])

            if 'rules' in self._loaded:
                idx_list.append(self.rules)

            for idx in idx_list:
                flush(idx)

    def create(self, kind, ifname, reuse=False, **kwarg):
        return self.interfaces.add(kind, ifname, reuse, **kwarg)

    def commit(self, transactions=None, phase=1):
        '''
        Commit `(target, tx)` pairs (by default every started
        interface and route transaction, highest 'ipdb_priority'
        first), rolling back to snapshots with a phase 2 commit if
        any target fails; the failure is re-raised.
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            txlist = []
            # a plugin that was not registered has no attribute at all;
            # skip it instead of failing with AttributeError
            by_name = getattr(self, 'by_name', None)
            if by_name is not None:
                # collect interface transactions
                txlist.extend((x, x.current_tx) for x in by_name.values()
                              if x.local_tx.values())
            routes = getattr(self, 'routes', None)
            if routes is not None:
                # collect route transactions
                for table in routes.tables.keys():
                    txlist.extend((x, x.current_tx)
                                  for x in routes.tables[table]
                                  if x.local_tx.values())
            transactions = sorted(txlist,
                                  key=lambda x: x[1]['ipdb_priority'],
                                  reverse=True)

        snapshots = []
        removed = []

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if phase == 1:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                target.commit(transaction=tx,
                              commit_phase=phase,
                              commit_mask=phase)
        except Exception:
            if phase == 1:
                self.fallen = transactions
                self.commit(transactions=snapshots, phase=2)
            raise
        else:
            if phase == 1:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if phase == 1:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

    def watchdog(self, action='RTM_NEWLINK', **kwarg):
        return Watchdog(self, action, kwarg)

    def serve_forever(self):
        ###
        # Main monitoring cycle. It gets messages from the
        # default iproute queue and updates objects in the
        # database.
        #
        # Should not be called manually.
        ###

        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except Exception:
                log.error('Restarting IPDB instance after '
                          'error:\n%s', traceback.format_exc())
                if self.restart_on_error:
                    try:
                        self.initdb()
                    except Exception:
                        log.error('Error restarting DB:\n%s',
                                  traceback.format_exc())
                        return
                    # initdb() has started a new event loop thread; this
                    # one must exit, or two loops would consume self.mnl
                    return
                else:
                    raise RuntimeError('Emergency shutdown')

            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except Exception:
                        # keep the loop alive, but do not hide the error
                        log.error('Pre-callback error:\n%s',
                                  traceback.format_exc())

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in self._event_map:
                        for func in self._event_map[event]:
                            func(msg)

                # run post-callbacks
                # NOTE: post-callbacks are asynchronous
                for (cuid, cb) in tuple(self._post_callbacks.items()):
                    t = threading.Thread(name="IPDB callback %s" % (id(cb)),
                                         target=cb,
                                         args=(self, msg, msg['event']))
                    t.start()
                    if cuid not in self._cb_threads:
                        self._cb_threads[cuid] = set()
                    self._cb_threads[cuid].add(t)

                # occasionally join cb threads
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads.get(cuid, ())):
                        t.join(0)
                        if not t.is_alive():
                            try:
                                self._cb_threads[cuid].remove(t)
                            except KeyError:
                                pass
                            if len(self._cb_threads.get(cuid, ())) == 0:
                                self._cb_threads.pop(cuid, None)

    def init_ipaddr_set(self):
        if self.sort_addresses:
            return SortedIPaddrSet()
        else:
            return IPaddrSet()
Esempio n. 4
0
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 sndbuf=1048576, rcvbuf=1048576,
                 nl_bind_groups=RTMGRP_DEFAULTS,
                 ignore_rtables=None, callbacks=None,
                 sort_addresses=False, plugins=None):
        '''
        Create the database and start its service threads.

        :param nl: a command netlink socket to reuse; when None, IPDB
            creates and owns an `IPRoute`.
        :param mode: stored as `self.mode` ('implicit' by default).
        :param restart_on_error: None means "only when IPDB owns the
            command socket".
        :param nl_async: passed to the monitoring socket's `bind()` as
            `async_cache`; None means use `config.ipdb_nl_async`.
        :param sndbuf: send buffer size for the owned `IPRoute`.
        :param rcvbuf: receive buffer size for the owned `IPRoute`.
        :param nl_bind_groups: netlink groups to subscribe to; plugins
            whose groups are not all included are not registered.
        :param ignore_rtables: an int or a list/tuple/set of ints,
            stored as `self._ignore_rtables` (presumably routing table
            ids to ignore).
        :param callbacks: callables ('post' mode) or `(callable, mode)`
            pairs to register, see `register_callback()`.
        :param sort_addresses: use `SortedIPaddrSet` instead of
            `IPaddrSet` for address sets.
        :param plugins: names out of 'interfaces', 'routes', 'rules'
            (all three by default); unknown names are ignored.
        '''
        plugins = plugins or ['interfaces', 'routes', 'rules']
        pmap = {'interfaces': interfaces,
                'routes': routes,
                'rules': rules}
        self.mode = mode
        self.txdrop = False
        self._stdout = sys.stdout
        self._ipaddr_set = SortedIPaddrSet if sort_addresses else IPaddrSet
        self._event_map = {}
        self._deferred = {}
        self._ensure = []
        self._loaded = set()
        self._mthread = None
        self._nl_own = nl is None
        # only None falls back to the config default; an explicit
        # nl_async=False must not be turned into True
        if nl_async is None:
            self._nl_async = config.ipdb_nl_async
        else:
            self._nl_async = bool(nl_async)
        self.mnl = None
        self.nl = nl
        self._sndbuf = sndbuf
        self._rcvbuf = rcvbuf
        self.nl_bind_groups = nl_bind_groups
        self._plugins = [pmap[x] for x in plugins if x in pmap]
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}

        # local event queues
        # - callbacks event queue
        self._cbq = queue.Queue(maxsize=8192)
        self._cbq_drop = 0
        # - users event queue
        self._evq = None
        self._evq_lock = threading.Lock()
        self._evq_drop = 0

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # register callbacks
        #
        # examples::
        #   def cb1(ipdb, msg, event):
        #       print(event, msg)
        #   def cb2(...):
        #       ...
        #
        #   # default mode: post
        #   IPDB(callbacks=[cb1, cb2])
        #   # specify the mode explicitly
        #   IPDB(callbacks=[(cb1, 'pre'), (cb2, 'post')])
        #
        for cba in callbacks or []:
            if not isinstance(cba, (tuple, list, set)):
                cba = (cba, )
            self.register_callback(*cba)

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None

        # init the database
        self.initdb()

        # init the dir() cache
        self.__dir_cache__ = [i for i in self.__class__.__dict__.keys()
                              if i[0] != '_']
        self.__dir_cache__.extend(list(self._deferred.keys()))

        # atexit keeps only a weak reference, so it does not pin the
        # instance in memory until interpreter exit
        def cleanup(ref):
            ipdb_obj = ref()
            if ipdb_obj is not None:
                ipdb_obj.release()
        atexit.register(cleanup, weakref.ref(self))

    def __dir__(self):
        # the listing is computed once in __init__: public class
        # attributes plus the names of the deferred plugins
        listing = self.__dir_cache__
        return listing

    def __enter__(self):
        # the context manager value is the instance itself
        instance = self
        return instance

    def __exit__(self, exc_type, exc_value, traceback):
        # leaving the `with` block always shuts the instance down;
        # returning None lets any in-flight exception propagate
        shutdown = self.release
        shutdown()
        return None

    def _flush_db(self):
        '''
        Detach every known interface, then empty the loaded indices
        (addresses, neighbours, route tables, rules) in place.
        Plugins that were never loaded are not touched.
        '''
        loaded = self._loaded
        targets = []
        if 'interfaces' in loaded:
            for name, dev in self.by_name.items():
                try:
                    # FIXME
                    self.interfaces._detach(name, dev['index'], dev.nlmsg)
                except KeyError:
                    pass
            targets += [self.ipaddr, self.neighbours]
        if 'routes' in loaded:
            tables = self.routes.tables
            targets.extend([tables[tid] for tid in tables.keys()])
        if 'rules' in loaded:
            targets.append(self.rules)

        for index in targets:
            # snapshot the keys first: the index shrinks as we delete
            for key in tuple(index.keys()):
                try:
                    del index[key]
                except KeyError:
                    pass

    def initdb(self):
        '''
        (Re)build the database under `self.exclusive`:
        - drop the old event handlers and DB objects;
        - (re)open the command socket when IPDB owns it, and the
          monitoring socket;
        - re-register the plugins as deferred attributes;
        - start the service threads that are not running.

        :raises: whatever the monitoring socket's `bind()` raises,
            after closing the sockets IPDB owns.
        '''
        # flush all the DB objects
        with self.exclusive:

            # explicitly cleanup object references
            for event in tuple(self._event_map):
                del self._event_map[event]

            self._flush_db()

            # if the command socket is not provided, create it
            if self._nl_own:
                if self.nl is not None:
                    self.nl.close()
                self.nl = IPRoute(sndbuf=self._sndbuf, rcvbuf=self._rcvbuf)
            # setup monitoring socket
            if self.mnl is not None:
                self._flush_mnl()
                self.mnl.close()
            self.mnl = self.nl.clone()
            try:
                self.mnl.bind(groups=self.nl_bind_groups,
                              async_cache=self._nl_async)
            except BaseException:
                self.mnl.close()
                # _nl_own is a bool, never None: the old `is None` test
                # meant an owned command socket was never closed here
                if self._nl_own:
                    self.nl.close()
                raise

            # explicitly cleanup references
            for key in tuple(self._deferred):
                del self._deferred[key]

            # register plugins as deferred (lazily loaded) attributes, but
            # only those whose netlink groups are all subscribed
            for module in self._plugins:
                if (module.groups & self.nl_bind_groups) != module.groups:
                    continue
                for plugin in module.spec:
                    self._deferred[plugin['name']] = module.spec
                    if plugin['name'] in self._loaded:
                        delattr(self, plugin['name'])
                        self._loaded.remove(plugin['name'])

            # start service threads; a thread that is still alive (when
            # initdb() runs again) is left as is
            for attr, target, name in (
                    ('_mthread', '_serve_main', 'IPDB main event loop'),
                    ('_cthread', '_serve_cb', 'IPDB cb event loop')):
                current = getattr(self, attr, None)
                if not getattr(current, 'is_alive', lambda: False)():
                    thread = threading.Thread(name=name,
                                              target=getattr(self, target))
                    setattr(self, attr, thread)
                    # Thread.setDaemon() is deprecated since Python 3.10
                    thread.daemon = True
                    thread.start()

    def __getattribute__(self, name):
        # Lazy plugin loading. Names registered in self._deferred map to a
        # plugin spec (a list of plugin descriptors). On first access to
        # any such name, every plugin of that spec is instantiated with
        # (self, **kwarg), attached as an attribute, marked as loaded and
        # removed from _deferred. Then each new object's _register() hook
        # is called (if present) and its _event_map handlers are appended
        # to self._event_map.
        # NOTE(review): there is no locking here; two threads touching a
        # deferred name at the same time could both load the spec and one
        # would hit KeyError on `del deferred[...]` -- confirm callers
        # serialize first access.
        # _deferred is read via object.__getattribute__ to avoid recursing
        # into this hook.
        deferred = super(IPDB, self).__getattribute__('_deferred')
        if name in deferred:
            register = []
            spec = deferred[name]
            for plugin in spec:
                obj = plugin['class'](self, **plugin['kwarg'])
                setattr(self, plugin['name'], obj)
                register.append(obj)
                self._loaded.add(plugin['name'])
                del deferred[plugin['name']]
            # hooks run only after all plugins of the spec exist, so they
            # can refer to each other
            for obj in register:
                if hasattr(obj, '_register'):
                    obj._register()
                if hasattr(obj, '_event_map'):
                    for event in obj._event_map:
                        if event not in self._event_map:
                            self._event_map[event] = []
                        self._event_map[event].append(obj._event_map[event])
        return super(IPDB, self).__getattribute__(name)

    def register_callback(self, callback, mode='post'):
        '''
        Register a routine to be run on every RT netlink message
        that IPDB receives, and return its id (pass the id to
        `unregister_callback()` to remove the routine later).

        Two modes are supported:

        - "post" (default): the routine runs after IPDB has
          processed the message and created or deleted the
          corresponding objects, so `ipdb` reflects the most
          up-to-date state of the IP database. Post callbacks run
          asynchronously in separate threads and may work as long
          as they need; the threads are joined occasionally, so
          stopped threads can exist for a short time.

        - "pre": the routine runs synchronously before IPDB
          processes the message. It can patch the arriving
          message, but the main IPDB event loop stays blocked
          until it returns. Normally only "post" callbacks are
          needed; "pre" is for specific cases.

        The routine is called with three arguments::

            cb(ipdb, msg, action)

        - **ipdb** -- the IPDB instance invoking the callback
        - **msg** -- the message that arrived
        - **action** -- the msg['event'] field

        For example, to work on a new interface, catch
        action == 'RTM_NEWLINK' and get the interface from IPDB by
        the index that arrived in msg['index']::

            index = msg['index']
            interface = ipdb.interfaces[index]

        Calls of one registered routine never overlap: each one is
        wrapped with its own lock.

        :param callback: the routine to register
        :param mode: "post" or "pre"
        :returns: the callback id
        :raises KeyError: on any other mode
        '''
        guard = threading.Lock()

        def safe(*argv, **kwarg):
            # serialize invocations of this particular callback
            with guard:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = guard
        safe.uuid = uuid32()

        if mode not in ('post', 'pre'):
            raise KeyError('Unknown callback mode')
        chain = self._post_callbacks if mode == 'post' else \
            self._pre_callbacks
        chain[safe.uuid] = safe
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback registered with `register_callback()`.

        :param cuid: the id returned by `register_callback()`
        :param mode: ``'post'`` (default) or ``'pre'``; must match the
            mode used at registration
        :returns: the removed callback wrapper
        :raises KeyError: for an unknown ``mode`` or an unknown ``cuid``
        '''
        if mode == 'post':
            chain = self._post_callbacks
        elif mode == 'pre':
            chain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        wrapper = chain[cuid]
        # the wrapper holds this lock while the callback runs, so we
        # wait for an in-flight invocation before removing it
        with wrapper.lock:
            return chain.pop(cuid)

    def eventqueue(self, qsize=8192, block=True, timeout=None):
        '''
        Initialize the users event queue and return its context manager.

        Events are collected from the moment the context manager is
        entered. This makes it possible to read the initial state from
        the system without losing last-moment changes, and to start
        processing events once that is done.

        :param qsize: maximum size of the event queue
        :param block: passed to the event queue context manager
        :param timeout: passed to the event queue context manager

        Example::

            ipdb = IPDB()
            with ipdb.eventqueue() as evq:
                my_state = ipdb.<needed_attribute>...
                for msg in evq:
                    update_state_by_msg(my_state, msg)
        '''
        context = _evq_context(self, qsize, block, timeout)
        return context

    def eventloop(self, qsize=8192, block=True, timeout=None):
        """
        Yield events as they happen.

        A generator for simple cases where no initial state has to be
        read first: it opens `eventqueue()` with the given parameters
        and yields every message taken from it. The queue is released
        when the generator is closed or exhausted.
        """
        queue_ctx = self.eventqueue(qsize=qsize, block=block,
                                    timeout=timeout)
        with queue_ctx as events:
            for event in events:
                yield event

    def release(self):
        '''
        Shutdown IPDB instance and sync the state. Since
        IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, before the script exits,
        IPDB must be shut down properly.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there
        is enough time to sync the state. But for scripts
        the `release()` call is required.

        The sequence, run under `_shutdown_lock`:

        1. set `_stop` and post a ShutdownException to the callbacks
           queue;
        2. if the main monitoring thread exists, wake it up and join it;
        3. close the monitoring netlink socket, and the main netlink
           socket if this instance created it;
        4. flush the database.

        If a shutdown has already started, the call only logs a warning
        and returns.
        '''
        with self._shutdown_lock:
            # repeated or concurrent release(): only the first one runs
            if self._stop:
                log.warning("shutdown in progress")
                return
            # _stop is polled by both service loops (_serve_main, _serve_cb)
            self._stop = True
            # wake up the callbacks thread blocked in _cbq.get()
            # NOTE(review): put() blocks while _cbq is full (maxsize=8192)
            # and nothing consumes it -- confirm this cannot hang shutdown
            self._cbq.put(ShutdownException("shutdown"))

            if self._mthread is not None:
                # unblock the main loop, then wait for it to exit
                self._flush_mnl()
                self._mthread.join()

            if self.mnl is not None:
                self.mnl.close()
                self.mnl = None

            # only close the main socket if it was not supplied by the caller
            if self._nl_own:
                self.nl.close()
                self.nl = None

            self._flush_db()

    def _flush_mnl(self):
        if self.mnl is not None:
            # terminate the main loop
            for t in range(3):
                try:
                    msg = ifinfmsg()
                    msg['index'] = 1
                    msg.reset()
                    self.mnl.put(msg, RTM_GETLINK)
                except Exception as e:
                    log.error("shutdown error: %s", e)
                    # Just give up.
                    # We can not handle this case

    def create(self, kind, ifname, reuse=False, **kwarg):
        '''
        Shortcut for ``self.interfaces.add(kind, ifname, reuse, **kwarg)``.

        Returns whatever the interfaces ``add()`` call returns.
        '''
        registry = self.interfaces
        return registry.add(kind, ifname, reuse, **kwarg)

    def ensure(self, cmd='add', reachable=None, condition=None):
        '''
        Manage the list of "ensure" checks.

        :param cmd: one of

            - ``'add'`` (default): if ``reachable`` is a string, append an
              ICMP reachability check for that host. The ``'a:b'`` form is
              not implemented and raises NotImplementedError. Without a
              string ``reachable``, print the list when stdin is a tty.
            - ``'reset'``: clear the list.
            - ``'run'``: call every check in order.
            - ``'print'``: pretty-print the list to the IPDB stdout.
            - ``'get'``: return the list itself.

        :param reachable: target host for ``'add'``
        :param condition: accepted but not used
        :raises NotImplementedError: for an unknown ``cmd`` or an
            unsupported ``reachable`` format
        '''
        if cmd == 'get':
            return self._ensure
        if cmd == 'reset':
            self._ensure = []
        elif cmd == 'run':
            for check in self._ensure:
                check()
        elif cmd == 'print':
            pprint(self._ensure, stream=self._stdout)
        elif cmd == 'add':
            if not isinstance(reachable, basestring):
                # no target: in interactive use just show the list
                if sys.stdin.isatty():
                    pprint(self._ensure, stream=self._stdout)
                return
            target = reachable.split(':')
            if len(target) != 1:
                raise NotImplementedError()
            self._ensure.append(partial(test_reachable_icmp, target[0]))
        else:
            raise NotImplementedError()

    def items(self):
        '''
        Yield ``(path, object)`` pairs for interfaces and routes.

        Interfaces come first as ``(('interfaces', ifname), iface)`` for
        every name in ``by_name``. Routes follow as
        ``(('routes', table, key), route)`` for every table in
        ``routes.tables``. A missing attribute yields nothing for that
        part.
        '''
        # TODO: add support for filters?
        for name in getattr(self, 'by_name', {}):
            yield (('interfaces', name), self.interfaces[name])

        # looked up only after the interfaces are exhausted
        route_tables = getattr(getattr(self, 'routes', None), 'tables', {})
        for table_id in route_tables:
            for route_key, route in self.routes.tables[table_id].items():
                yield (('routes', table_id, route_key), route)

    def dump(self):
        '''
        Return the database as a nested dict built from `items()`.

        Each item path becomes a chain of dict keys, e.g.
        ``{'interfaces': {ifname: iface}, 'routes': {table: {key: route}}}``.
        '''
        tree = {}
        for path, obj in self.items():
            node = tree
            for step in path[:-1]:
                node = node.setdefault(step, {})
            node[path[-1]] = obj
        return tree

    def load(self, config, ptr=None):
        '''
        Apply a nested configuration dict, e.g. one produced by `dump()`.

        For every key of ``config``, look up the attribute of the same
        name on ``ptr`` (default: this IPDB instance):

        - if it exists and has ``load()``, pass it the sub-config;
        - if it exists without ``load()``, recurse into it;
        - if it does not exist and ``ptr`` has ``add()``, call
          ``ptr.add(**sub_config)``; otherwise skip the key.

        :returns: self
        '''
        target = self if ptr is None else ptr
        for key in config:
            child = getattr(target, key, None)
            if child is None:
                if hasattr(target, 'add'):
                    target.add(**config[key])
            elif hasattr(child, 'load'):
                child.load(config[key])
            else:
                self.load(config[key], ptr=child)
        return self

    def review(self):
        '''
        Collect ``obj.review()`` of every object from `items()`.

        Objects whose ``review()`` raises TypeError are skipped. The
        results are arranged in the same nested layout as `dump()`.

        :raises TypeError: if no object produced a review
            ('no transaction started')
        '''
        summary = {}
        for path, obj in self.items():
            try:
                rev = obj.review()
            except TypeError:
                continue
            node = summary
            for step in path[:-1]:
                node = node.setdefault(step, {})
            node[path[-1]] = rev

        if not summary:
            raise TypeError('no transaction started')
        return summary

    def drop(self):
        '''
        Call ``drop()`` on every object from `items()`.

        Objects whose ``drop()`` raises TypeError are skipped; all the
        others are still processed.

        :raises TypeError: if no object was dropped successfully
            ('no transaction started')
        '''
        dropped = 0
        for _path, obj in self.items():
            try:
                obj.drop()
            except TypeError:
                continue
            dropped += 1
        if dropped == 0:
            raise TypeError('no transaction started')

    def commit(self, transactions=None, phase=1):
        '''
        Commit interface and route transactions as one batch.

        :param transactions: iterable of ``(target, tx)`` pairs to commit.
            If None, the ``current_tx`` is collected from every interface
            in ``by_name`` and every route in ``routes.tables`` that has a
            non-empty ``local_tx``.
        :param phase: 1 for a regular commit. In phase 1 each target is
            snapshotted (``target.pick(detached=True)``) before its
            commit. If any commit fails, this method calls itself with
            ``phase=2`` and the snapshots as ``transactions`` to roll
            back, then re-raises the error. On success, targets whose
            transaction had ``ipdb_scope == 'remove'`` are detached. In
            phase 1 every transaction is dropped from its target at the
            end, whether the commit succeeded or not.
        :returns: self
        :raises: re-raises whatever ``target.commit()`` raised.

        Transactions are reordered before being applied; see the
        ordering comment in the body.
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            # collect interface transactions
            txlist = [(x, x.current_tx) for x
                      in getattr(self, 'by_name', {}).values()
                      if x.local_tx.values()]
            # collect route transactions
            for table in getattr(getattr(self, 'routes', None),
                                 'tables', {}).keys():
                txlist.extend([(x, x.current_tx) for x in
                               self.routes.tables[table]
                               if x.local_tx.values()])
            transactions = txlist

        # phase 1 only: (target, snapshot) pairs used for rollback
        snapshots = []
        # (target, tx) pairs whose tx requested removal
        removed = []

        tx_ipdb_prio = []
        tx_main = []
        tx_prio1 = []
        tx_prio2 = []
        tx_prio3 = []
        for (target, tx) in transactions:
            # 8<------------------------------
            # first -- explicit priorities (a falsy ipdb_priority,
            # e.g. 0 or None, means "not set")
            if tx['ipdb_priority']:
                tx_ipdb_prio.append((target, tx))
                continue
            # 8<------------------------------
            # routes
            if isinstance(target, BaseRoute):
                tx_prio3.append((target, tx))
                continue
            # 8<------------------------------
            # interfaces
            kind = target.get('kind', None)
            if kind in ('vlan', 'vxlan', 'gre', 'tuntap', 'vti', 'vti6',
                        'vrf'):
                tx_prio1.append((target, tx))
            elif kind in ('bridge', 'bond'):
                tx_prio2.append((target, tx))
            else:
                tx_main.append((target, tx))
            # 8<------------------------------

        # explicitly sorted transactions, highest ipdb_priority first
        tx_ipdb_prio = sorted(tx_ipdb_prio,
                              key=lambda x: x[1]['ipdb_priority'],
                              reverse=True)

        # The final transactions order:
        # 1. any txs with ipdb_priority (sorted by that field)
        #
        # Then come default priorities (no ipdb_priority specified):
        # 2. all the rest
        # 3. vlan, vxlan, gre, tuntap, vti, vti6, vrf
        # 4. bridge, bond
        # 5. routes
        transactions = tx_ipdb_prio + tx_main + tx_prio1 + tx_prio2 + tx_prio3

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if phase == 1:
                    # take the snapshot before committing, so the failing
                    # target itself is rolled back as well
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                # apply the changes; the rollback on failure is done
                # by the except branch below, in phase 1 only
                target.commit(transaction=tx,
                              commit_phase=phase,
                              commit_mask=phase)
                # if the commit above fails, the next code
                # branch will run rollbacks
        except Exception:
            if phase == 1:
                # run rollbacks for ALL the collected transactions,
                # even successful ones
                # NOTE(review): presumably kept for post-mortem inspection
                self.fallen = transactions
                # skip snapshots where both the target and the snapshot
                # have ipdb_scope 'create'
                txs = filter(lambda x: not ('create' ==
                                            x[0]['ipdb_scope'] ==
                                            x[1]['ipdb_scope']), snapshots)
                # NOTE(review): if this rollback raises, its exception
                # replaces the original one
                self.commit(transactions=txs, phase=2)
            raise
        else:
            if phase == 1:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if phase == 1:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

        return self

    def watchdog(self, wdops='RTM_NEWLINK', **kwarg):
        '''
        Create and return ``Watchdog(self, wdops, kwarg)``.

        :param wdops: event type passed to the Watchdog
            (default ``'RTM_NEWLINK'``)
        :param kwarg: keyword arguments, passed on as a single dict
        '''
        dog = Watchdog(self, wdops, kwarg)
        return dog

    def _serve_cb(self):
        '''
        Callbacks thread: deliver queued messages to "post" callbacks.

        Reads the dedicated ``_cbq`` queue until `_stop` is set or a
        ShutdownException is dequeued (`release()` posts one). Any other
        Exception found in the queue is re-raised and ends this loop.
        A failing callback is logged with its traceback, and the
        remaining callbacks still get the message.
        '''
        while not self._stop:
            msg = self._cbq.get()
            self._cbq.task_done()
            if isinstance(msg, ShutdownException):
                return
            elif isinstance(msg, Exception):
                raise msg
            # iterate over a snapshot: callbacks may be (un)registered
            # concurrently from other threads
            for cb in tuple(self._post_callbacks.values()):
                try:
                    cb(self, msg, msg['event'])
                except Exception:
                    # a broken callback must not stop event delivery,
                    # but it must not fail silently either
                    log.error('post-callback %r failed:\n%s',
                              getattr(cb, 'hook', cb),
                              traceback.format_exc())

    def _serve_main(self):
        ###
        # Main monitoring cycle. It gets messages from the
        # default iproute queue and updates objects in the
        # database.
        ###

        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except Exception as e:
                with self.exclusive:
                    if self._evq:
                        self._evq.put(e)
                        return
                if self.restart_on_error:
                    log.error('Restarting IPDB instance after '
                              'error:\n%s', traceback.format_exc())
                    try:
                        self.initdb()
                    except:
                        log.error('Error restarting DB:\n%s',
                                  traceback.format_exc())
                        return
                    continue
                else:
                    log.error('Emergency shutdown, cleanup manually')
                    raise RuntimeError('Emergency shutdown')

            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except:
                        pass

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in self._event_map:
                        for func in self._event_map[event]:
                            func(msg)

                    # Post-callbacks
                    try:
                        self._cbq.put_nowait(msg)
                        if self._cbq_drop:
                            log.warning('dropped %d events',
                                        self._cbq_drop)
                            self._cbq_drop = 0
                    except queue.Full:
                        self._cbq_drop += 1
                    except Exception:
                        log.error('Emergency shutdown, cleanup manually')
                        raise RuntimeError('Emergency shutdown')

                    #
                    # Why not to put these two pieces of the code
                    # it in a routine?
                    #
                    # TODO: run performance tests with routines

                    # Users event queue
                    if self._evq:
                        try:
                            self._evq.put_nowait(msg)
                            if self._evq_drop:
                                log.warning("dropped %d events",
                                            self._evq_drop)
                                self._evq_drop = 0
                        except queue.Full:
                            self._evq_drop += 1
                        except Exception as e:
                            log.error('Emergency shutdown, cleanup manually')
                            raise RuntimeError('Emergency shutdown')