class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 sndbuf=1048576, rcvbuf=1048576,
                 nl_bind_groups=RTMGRP_DEFAULTS,
                 ignore_rtables=None, callbacks=None,
                 sort_addresses=False, plugins=None):
        '''
        :param nl: an existing netlink socket to use for commands; when
            omitted IPDB creates (and owns) its own IPRoute socket
        :param mode: transaction mode, stored as `self.mode`
        :param restart_on_error: restart the DB when the monitoring
            socket fails; defaults to True only for an IPDB-owned socket
        :param nl_async: passed as `async_cache` to the monitoring
            socket bind; None means `config.ipdb_nl_async`
        :param sndbuf: send buffer size for the owned IPRoute socket
        :param rcvbuf: receive buffer size for the owned IPRoute socket
        :param nl_bind_groups: netlink multicast groups to subscribe to
        :param ignore_rtables: an int or a collection of route tables
        :param callbacks: callables, or (callable, mode) tuples, passed
            to `register_callback()`
        :param sort_addresses: use SortedIPaddrSet instead of IPaddrSet
        :param plugins: names among 'interfaces', 'routes', 'rules';
            unknown names are silently skipped
        '''
        plugins = plugins or ['interfaces', 'routes', 'rules']
        pmap = {'interfaces': interfaces,
                'routes': routes,
                'rules': rules}
        self.mode = mode
        self.txdrop = False
        self._stdout = sys.stdout
        self._ipaddr_set = SortedIPaddrSet if sort_addresses else IPaddrSet
        self._event_map = {}
        # NOTE: `_deferred` must exist before the first attribute *read*,
        # since __getattribute__ consults it on every access
        self._deferred = {}
        self._ensure = []
        self._loaded = set()
        self._mthread = None
        self._nl_own = nl is None
        # None -> use the global default; otherwise honour the caller's
        # flag (an explicit False used to be coerced to True)
        self._nl_async = (config.ipdb_nl_async if nl_async is None
                          else bool(nl_async))
        self.mnl = None
        self.nl = nl
        self._sndbuf = sndbuf
        self._rcvbuf = rcvbuf
        self.nl_bind_groups = nl_bind_groups
        self._plugins = [pmap[x] for x in plugins if x in pmap]
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}

        # local event queues
        # - callbacks event queue
        self._cbq = queue.Queue(maxsize=8192)
        self._cbq_drop = 0
        # - users event queue
        self._evq = None
        self._evq_lock = threading.Lock()
        self._evq_drop = 0

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # register callbacks
        #
        # examples::
        #     def cb1(ipdb, msg, event):
        #         print(event, msg)
        #     def cb2(...):
        #         ...
        #
        #     # default mode: post
        #     IPDB(callbacks=[cb1, cb2])
        #     # specify the mode explicitly
        #     IPDB(callbacks=[(cb1, 'pre'), (cb2, 'post')])
        #
        for cba in callbacks or []:
            if not isinstance(cba, (tuple, list, set)):
                cba = (cba, )
            self.register_callback(*cba)

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None

        # init the database
        self.initdb()

        # init the dir() cache
        self.__dir_cache__ = [i for i in self.__class__.__dict__.keys()
                              if i[0] != '_']
        self.__dir_cache__.extend(list(self._deferred.keys()))

        def cleanup(ref):
            # a weak reference lets the instance be collected before exit
            ipdb_obj = ref()
            if ipdb_obj is not None:
                ipdb_obj.release()

        atexit.register(cleanup, weakref.ref(self))

    def __dir__(self):
        return self.__dir_cache__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def _flush_db(self):
        '''
        Detach all the loaded interfaces and empty every loaded index
        (addresses, neighbours, route tables, rules).
        '''

        def flush(idx):
            for key in tuple(idx.keys()):
                try:
                    del idx[key]
                except KeyError:
                    pass

        idx_list = []

        if 'interfaces' in self._loaded:
            # snapshot the items: _detach() may drop entries from the
            # underlying index while we iterate
            for (key, dev) in tuple(self.by_name.items()):
                try:
                    # FIXME
                    self.interfaces._detach(key,
                                            dev['index'],
                                            dev.nlmsg)
                except KeyError:
                    pass
            idx_list.append(self.ipaddr)
            idx_list.append(self.neighbours)

        if 'routes' in self._loaded:
            idx_list.extend([self.routes.tables[x] for x
                             in self.routes.tables.keys()])

        if 'rules' in self._loaded:
            idx_list.append(self.rules)

        for idx in idx_list:
            flush(idx)

    def initdb(self):
        '''
        (Re)initialize the database: flush objects, (re)create the
        sockets, re-register deferred plugins and make sure the service
        threads are running.
        '''
        # flush all the DB objects
        with self.exclusive:

            # explicitly cleanup object references
            for event in tuple(self._event_map):
                del self._event_map[event]

            self._flush_db()

            # if the command socket is not provided, create it
            if self._nl_own:
                if self.nl is not None:
                    self.nl.close()
                self.nl = IPRoute(sndbuf=self._sndbuf,
                                  rcvbuf=self._rcvbuf)
            # setup monitoring socket
            if self.mnl is not None:
                self._flush_mnl()
                self.mnl.close()
            self.mnl = self.nl.clone()
            try:
                self.mnl.bind(groups=self.nl_bind_groups,
                              async_cache=self._nl_async)
            except BaseException:
                self.mnl.close()
                # close the command socket only when IPDB owns it; the
                # former `self._nl_own is None` test was never true (it
                # is a bool), so an owned socket could leak here
                if self._nl_own:
                    self.nl.close()
                    self.nl = None
                raise

            # explicitly cleanup references
            for key in tuple(self._deferred):
                del self._deferred[key]

            for module in self._plugins:
                if (module.groups & self.nl_bind_groups) != module.groups:
                    continue
                for plugin in module.spec:
                    self._deferred[plugin['name']] = module.spec
                    if plugin['name'] in self._loaded:
                        delattr(self, plugin['name'])
                        self._loaded.remove(plugin['name'])

            # start service threads, but only those not already running
            # (initdb() may be called from the main loop itself)
            for tspec in (('_mthread', '_serve_main', 'IPDB main event loop'),
                          ('_cthread', '_serve_cb', 'IPDB cb event loop')):
                tg = getattr(self, tspec[0], None)
                if not getattr(tg, 'is_alive', lambda: False)():
                    tx = threading.Thread(name=tspec[2],
                                          target=getattr(self, tspec[1]))
                    setattr(self, tspec[0], tx)
                    tx.daemon = True
                    tx.start()

    def __getattribute__(self, name):
        # lazily instantiate deferred plugins on first access
        deferred = super(IPDB, self).__getattribute__('_deferred')
        if name in deferred:
            register = []
            spec = deferred[name]
            for plugin in spec:
                obj = plugin['class'](self, **plugin['kwarg'])
                setattr(self, plugin['name'], obj)
                register.append(obj)
                self._loaded.add(plugin['name'])
                del deferred[plugin['name']]
            for obj in register:
                if hasattr(obj, '_register'):
                    obj._register()
                if hasattr(obj, '_event_map'):
                    for event in obj._event_map:
                        if event not in self._event_map:
                            self._event_map[event] = []
                        self._event_map[event].append(obj._event_map[event])
        return super(IPDB, self).__getattribute__(name)

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post")

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]

        Returns the callback id to be used with `unregister_callback()`;
        raises KeyError on an unknown mode.
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        else:
            raise KeyError('Unknown callback mode')
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback registered by `register_callback()`; waits
        for a running invocation (the wrapper lock) before removal.
        '''
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        safe = cbchain[cuid]
        with safe.lock:
            ret = cbchain.pop(cuid)
        return ret

    def eventqueue(self, qsize=8192, block=True, timeout=None):
        '''
        Initializes event queue and returns event queue context manager.
        Once the context manager is initialized, events start to be
        collected, so it is possible to read initial state from the
        system without losing last moment changes, and once that is
        done, start processing events.

        Example:

            ipdb = IPDB()
            with ipdb.eventqueue() as evq:
                my_state = ipdb.<needed_attribute>...
                for msg in evq:
                    update_state_by_msg(my_state, msg)
        '''
        return _evq_context(self, qsize, block, timeout)

    def eventloop(self, qsize=8192, block=True, timeout=None):
        """
        Event generator for simple cases when there is no need for initial
        state setup. Initialize event queue and yield events as they happen.
        """
        with self.eventqueue(qsize=qsize, block=block, timeout=timeout) as evq:
            for msg in evq:
                yield msg

    def release(self):
        '''
        Shutdown IPDB instance and sync the state.

        Since IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exit the script,
        it is required to properly shutdown IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there is
        enough time to sync the state. But for the scripts the
        `release()` call is required.
        '''
        with self._shutdown_lock:
            if self._stop:
                log.warning("shutdown in progress")
                return
            self._stop = True
            # wake up and stop the callbacks thread
            self._cbq.put(ShutdownException("shutdown"))

            if self._mthread is not None:
                self._flush_mnl()
                self._mthread.join()

            if self.mnl is not None:
                self.mnl.close()
                self.mnl = None

            # self.nl may already be cleared by a failed initdb()
            if self._nl_own and self.nl is not None:
                self.nl.close()
                self.nl = None

            self._flush_db()

    def _flush_mnl(self):
        '''
        Push a few RTM_GETLINK requests through the monitoring socket so
        a main loop blocked in `mnl.get()` wakes up and sees `_stop`.
        '''
        if self.mnl is not None:
            # terminate the main loop
            for _ in range(3):
                try:
                    msg = ifinfmsg()
                    msg['index'] = 1
                    msg.reset()
                    self.mnl.put(msg, RTM_GETLINK)
                except Exception as e:
                    log.error("shutdown error: %s", e)
                    # Just give up.
                    # We can not handle this case

    def create(self, kind, ifname, reuse=False, **kwarg):
        return self.interfaces.add(kind, ifname, reuse, **kwarg)

    def ensure(self, cmd='add', reachable=None, condition=None):
        '''
        Manage the list of "ensure" checks: 'reset', 'run', 'add',
        'print' or 'get'; any other command raises NotImplementedError.
        '''
        if cmd == 'reset':
            self._ensure = []
        elif cmd == 'run':
            for f in self._ensure:
                f()
        elif cmd == 'add':
            if isinstance(reachable, basestring):
                reachable = reachable.split(':')
                if len(reachable) == 1:
                    f = partial(test_reachable_icmp, reachable[0])
                else:
                    raise NotImplementedError()
                self._ensure.append(f)
            else:
                if sys.stdin.isatty():
                    pprint(self._ensure, stream=self._stdout)
        elif cmd == 'print':
            pprint(self._ensure, stream=self._stdout)
        elif cmd == 'get':
            return self._ensure
        else:
            raise NotImplementedError()

    def items(self):
        '''
        Yield ``(key_path, object)`` pairs for every interface and route;
        key paths are ('interfaces', ifname) or ('routes', table, key).
        '''
        # TODO: add support for filters?

        # iterate interfaces
        for ifname in getattr(self, 'by_name', {}):
            yield (('interfaces', ifname), self.interfaces[ifname])

        # iterate routes
        for table in getattr(getattr(self, 'routes', None),
                             'tables', {}):
            for key, route in self.routes.tables[table].items():
                yield (('routes', table, key), route)

    def dump(self):
        '''Return a nested dict built from the `items()` key paths.'''
        ret = {}
        for key, obj in self.items():
            ptr = ret
            for step in key[:-1]:
                if step not in ptr:
                    ptr[step] = {}
                ptr = ptr[step]
            ptr[key[-1]] = obj
        return ret

    def load(self, config, ptr=None):
        '''
        Recursively apply a nested dict (as produced by `dump()`):
        objects with a `load()` method load their subtree, containers
        with `add()` get new entries created.
        '''
        if ptr is None:
            ptr = self

        for key in config:
            obj = getattr(ptr, key, None)
            if obj is not None:
                if hasattr(obj, 'load'):
                    obj.load(config[key])
                else:
                    self.load(config[key], ptr=obj)
            elif hasattr(ptr, 'add'):
                ptr.add(**config[key])

        return self

    def review(self):
        '''
        Return a nested dict of pending transaction reviews; objects
        whose `review()` raises TypeError are skipped. Raises TypeError
        when nothing has a transaction started.
        '''
        ret = {}
        for key, obj in self.items():
            ptr = ret
            try:
                rev = obj.review()
            except TypeError:
                continue

            for step in key[:-1]:
                if step not in ptr:
                    ptr[step] = {}
                ptr = ptr[step]
            ptr[key[-1]] = rev

        if not ret:
            raise TypeError('no transaction started')
        return ret

    def drop(self):
        '''
        Drop pending transactions on all objects; raises TypeError when
        no object had one.
        '''
        ok = False
        for key, obj in self.items():
            try:
                obj.drop()
            except TypeError:
                continue
            ok = True
        if not ok:
            raise TypeError('no transaction started')

    def commit(self, transactions=None, phase=1):
        '''
        Commit transactions in priority order. On a phase 1 failure,
        snapshots taken before each commit are replayed as phase 2
        (rollback) and the exception is re-raised.

        :param transactions: iterable of (target, tx) pairs; when None,
            the current transactions of interfaces and routes are used
        :param phase: 1 for a normal commit, 2 for a rollback
        :returns: self
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            # collect interface transactions
            txlist = [(x, x.current_tx) for x
                      in getattr(self, 'by_name', {}).values()
                      if x.local_tx.values()]
            # collect route transactions
            for table in getattr(getattr(self, 'routes', None),
                                 'tables', {}).keys():
                txlist.extend([(x, x.current_tx) for x in
                               self.routes.tables[table]
                               if x.local_tx.values()])
            transactions = txlist

        snapshots = []
        removed = []

        tx_ipdb_prio = []
        tx_main = []
        tx_prio1 = []
        tx_prio2 = []
        tx_prio3 = []

        for (target, tx) in transactions:
            # 8<------------------------------
            # first -- explicit priorities
            if tx['ipdb_priority']:
                tx_ipdb_prio.append((target, tx))
                continue
            # 8<------------------------------
            # routes
            if isinstance(target, BaseRoute):
                tx_prio3.append((target, tx))
                continue
            # 8<------------------------------
            # interfaces
            kind = target.get('kind', None)
            if kind in ('vlan', 'vxlan', 'gre', 'tuntap',
                        'vti', 'vti6', 'vrf'):
                tx_prio1.append((target, tx))
            elif kind in ('bridge', 'bond'):
                tx_prio2.append((target, tx))
            else:
                tx_main.append((target, tx))
            # 8<------------------------------

        # explicitly sorted transactions
        tx_ipdb_prio = sorted(tx_ipdb_prio,
                              key=lambda x: x[1]['ipdb_priority'],
                              reverse=True)

        # FIXME: this should be documented
        #
        # The final transactions order:
        # 1. any txs with ipdb_priority (sorted by that field)
        #
        # Then come default priorities (no ipdb_priority specified):
        # 2. all the rest
        # 3. vlan, vxlan, gre, tuntap, vti, vrf
        # 4. bridge, bond
        # 5. routes
        transactions = tx_ipdb_prio + tx_main + tx_prio1 + tx_prio2 + tx_prio3

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if phase == 1:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                # apply the changes, but NO rollback -- only phase 1
                target.commit(transaction=tx,
                              commit_phase=phase,
                              commit_mask=phase)
                # if the commit above fails, the next code
                # branch will run rollbacks
        except Exception:
            if phase == 1:
                # run rollbacks for ALL the collected transactions,
                # even successful ones
                self.fallen = transactions
                txs = filter(lambda x: not ('create' ==
                                            x[0]['ipdb_scope'] ==
                                            x[1]['ipdb_scope']), snapshots)
                self.commit(transactions=txs, phase=2)
            raise
        else:
            if phase == 1:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if phase == 1:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

        return self

    def watchdog(self, wdops='RTM_NEWLINK', **kwarg):
        return Watchdog(self, wdops, kwarg)

    def _serve_cb(self):
        ###
        # Callbacks thread working on a dedicated event queue.
        # Exits on ShutdownException; any other exception object put
        # into the queue is re-raised here.
        ###

        while not self._stop:
            msg = self._cbq.get()
            self._cbq.task_done()
            if isinstance(msg, ShutdownException):
                return
            elif isinstance(msg, Exception):
                raise msg
            for cb in tuple(self._post_callbacks.values()):
                try:
                    cb(self, msg, msg['event'])
                except:
                    # post-callback errors are deliberately ignored
                    pass

    def _serve_main(self):
        ###
        # Main monitoring cycle. It gets messages from the
        # default iproute queue and updates objects in the
        # database.
        ###

        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except Exception as e:
                with self.exclusive:
                    if self._evq:
                        # hand the error to the user event queue
                        self._evq.put(e)
                        return
                if self.restart_on_error:
                    log.error('Restarting IPDB instance after '
                              'error:\n%s', traceback.format_exc())
                    try:
                        # initdb() leaves this (alive) thread in place
                        # and only replaces the monitoring socket
                        self.initdb()
                    except:
                        log.error('Error restarting DB:\n%s',
                                  traceback.format_exc())
                        return
                    continue
                else:
                    log.error('Emergency shutdown, cleanup manually')
                    raise RuntimeError('Emergency shutdown')

            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except:
                        pass

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in self._event_map:
                        for func in self._event_map[event]:
                            func(msg)

                # Post-callbacks
                try:
                    self._cbq.put_nowait(msg)
                    if self._cbq_drop:
                        log.warning('dropped %d events',
                                    self._cbq_drop)
                        self._cbq_drop = 0
                except queue.Full:
                    self._cbq_drop += 1
                except Exception:
                    log.error('Emergency shutdown, cleanup manually')
                    raise RuntimeError('Emergency shutdown')

                #
                # Why not to put these two pieces of the code
                # it in a routine?
                #
                # TODO: run performance tests with routines

                # Users event queue
                if self._evq:
                    try:
                        self._evq.put_nowait(msg)
                        if self._evq_drop:
                            log.warning("dropped %d events",
                                        self._evq_drop)
                            self._evq_drop = 0
                    except queue.Full:
                        self._evq_drop += 1
                    except Exception:
                        log.error('Emergency shutdown, cleanup manually')
                        raise RuntimeError('Emergency shutdown')
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 nl_bind_groups=RTNL_GROUPS,
                 ignore_rtables=None, callbacks=None,
                 sort_addresses=False, plugins=None):
        '''
        :param nl: an existing netlink socket; when omitted IPDB
            creates (and owns) its own IPRoute socket
        :param mode: transaction mode, stored as `self.mode`
        :param restart_on_error: restart the DB when the monitoring
            socket fails; defaults to True only for an IPDB-owned socket
        :param nl_async: passed as `async` to the monitoring socket
            bind; None means `config.ipdb_nl_async`
        :param nl_bind_groups: netlink multicast groups to subscribe to
        :param ignore_rtables: an int or a collection of route tables
        :param callbacks: callables, or (callable, mode) tuples, passed
            to `register_callback()`
        :param sort_addresses: make `init_ipaddr_set()` return a
            SortedIPaddrSet
        :param plugins: names among 'interfaces', 'routes', 'rules';
            unknown names are silently skipped
        '''
        plugins = plugins or ['interfaces', 'routes', 'rules']
        pmap = {'interfaces': interfaces,
                'routes': routes,
                'rules': rules}
        self.mode = mode
        self.sort_addresses = sort_addresses
        self._event_map = {}
        # NOTE: `_deferred` must exist before the first attribute *read*,
        # since __getattribute__ consults it on every access
        self._deferred = {}
        self._loaded = set()
        self._mthread = None
        self._nl_own = nl is None
        # None -> use the global default; otherwise honour the caller's
        # flag (an explicit False used to be coerced to True)
        self._nl_async = (config.ipdb_nl_async if nl_async is None
                          else bool(nl_async))
        self.mnl = None
        self.nl = nl
        self.nl_bind_groups = nl_bind_groups
        self._plugins = [pmap[x] for x in plugins if x in pmap]
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}
        # callback id -> set of threads running that post-callback
        self._cb_threads = {}

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # register callbacks
        #
        # examples::
        #     def cb1(ipdb, msg, event):
        #         print(event, msg)
        #     def cb2(...):
        #         ...
        #
        #     # default mode: post
        #     IPDB(callbacks=[cb1, cb2])
        #     # specify the mode explicitly
        #     IPDB(callbacks=[(cb1, 'pre'), (cb2, 'post')])
        #
        for cba in callbacks or []:
            if not isinstance(cba, (tuple, list, set)):
                cba = (cba, )
            self.register_callback(*cba)

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None

        # init the database
        self.initdb()

        #
        atexit.register(self.release)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def initdb(self):
        '''
        (Re)initialize the database: release the previous state, create
        the sockets, re-register deferred plugins and start a new
        monitoring thread.
        '''
        # common event map, empty by default, so all the
        # events are just ignored
        self.release(complete=False)
        self._stop = False
        # explicitly cleanup object references
        for event in tuple(self._event_map):
            del self._event_map[event]

        # if the command socket is not provided, create it
        if self._nl_own:
            self.nl = IPRoute()
        # setup monitoring socket
        self.mnl = self.nl.clone()
        try:
            # `async` is a reserved word since Python 3.7, so it can not
            # be spelled as a literal keyword argument any more
            self.mnl.bind(groups=self.nl_bind_groups,
                          **{'async': self._nl_async})
        except BaseException:
            self.mnl.close()
            # close the command socket only when IPDB owns it; the
            # former `self._nl_own is None` test was never true (it is
            # a bool), so an owned socket could leak here
            if self._nl_own:
                self.nl.close()
                self.nl = None
            raise

        # explicitly cleanup references
        for key in tuple(self._deferred):
            del self._deferred[key]

        for module in self._plugins:
            if (module.groups & self.nl_bind_groups) != module.groups:
                continue
            for plugin in module.spec:
                self._deferred[plugin['name']] = module.spec
                if plugin['name'] in self._loaded:
                    delattr(self, plugin['name'])
                    self._loaded.remove(plugin['name'])

        # start the monitoring thread
        self._mthread = threading.Thread(name="IPDB event loop",
                                         target=self.serve_forever)
        self._mthread.daemon = True
        self._mthread.start()

    def __getattribute__(self, name):
        # lazily instantiate deferred plugins on first access
        deferred = super(IPDB, self).__getattribute__('_deferred')
        if name in deferred:
            register = []
            spec = deferred[name]
            for plugin in spec:
                obj = plugin['class'](self, **plugin['kwarg'])
                setattr(self, plugin['name'], obj)
                register.append(obj)
                self._loaded.add(plugin['name'])
                del deferred[plugin['name']]
            for obj in register:
                if hasattr(obj, '_register'):
                    obj._register()
                if hasattr(obj, '_event_map'):
                    for event in obj._event_map:
                        if event not in self._event_map:
                            self._event_map[event] = []
                        self._event_map[event].append(obj._event_map[event])
        return super(IPDB, self).__getattribute__(name)

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post")

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]

        Returns the callback id to be used with `unregister_callback()`;
        raises KeyError on an unknown mode (previously the callback was
        silently never registered).
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        else:
            raise KeyError('Unknown callback mode')
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback and wait (up to 3s per thread) for its
        running threads; returns the threads still tracked for it.
        '''
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        safe = cbchain[cuid]
        with safe.lock:
            cbchain.pop(cuid)
        for t in tuple(self._cb_threads.get(cuid, ())):
            t.join(3)
        ret = self._cb_threads.get(cuid, ())
        return ret

    def release(self, complete=True):
        '''
        Shutdown IPDB instance and sync the state.

        Since IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exit the script,
        it is required to properly shutdown IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there is
        enough time to sync the state. But for the scripts the
        `release()` call is required.

        :param complete: also close a user-provided command socket
        '''
        with self._shutdown_lock:
            if self._stop:
                return
            self._stop = True
            if self.mnl is not None:
                # terminate the main loop
                for _ in range(3):
                    try:
                        msg = ifinfmsg()
                        msg['index'] = 1
                        msg.reset()
                        self.mnl.put(msg, RTM_GETLINK)
                    except Exception as e:
                        logging.warning("shutdown error: %s", e)
                        # Just give up.
                        # We can not handle this case

            # when called from the monitoring thread itself (restart
            # path in serve_forever), joining it would raise RuntimeError
            if self._mthread is not None and \
                    self._mthread is not threading.current_thread():
                self._mthread.join()

            if self.mnl is not None:
                self.mnl.close()
                self.mnl = None
                if (complete or self._nl_own) and self.nl is not None:
                    self.nl.close()
                    self.nl = None

            with self.exclusive:
                # collect all the callbacks
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads[cuid]):
                        t.join()

                # flush all the objects
                def flush(idx):
                    for key in tuple(idx.keys()):
                        try:
                            del idx[key]
                        except KeyError:
                            pass

                idx_list = []

                if 'interfaces' in self._loaded:
                    # snapshot: _detach() may drop index entries
                    for (key, dev) in tuple(self.by_name.items()):
                        try:
                            # FIXME
                            self.interfaces._detach(key,
                                                    dev['index'],
                                                    dev.nlmsg)
                        except KeyError:
                            pass
                    idx_list.append(self.ipaddr)
                    idx_list.append(self.neighbours)

                if 'routes' in self._loaded:
                    idx_list.extend(
                        [self.routes.tables[x]
                         for x in self.routes.tables.keys()])

                if 'rules' in self._loaded:
                    idx_list.append(self.rules)

                for idx in idx_list:
                    flush(idx)

    def create(self, kind, ifname, reuse=False, **kwarg):
        return self.interfaces.add(kind, ifname, reuse, **kwarg)

    def commit(self, transactions=None, phase=1):
        '''
        Commit transactions (highest `ipdb_priority` first when
        collected automatically). On a phase 1 failure, the snapshots
        taken before each commit are replayed as phase 2 (rollback) and
        the exception is re-raised.
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            # collect interface transactions
            txlist = [(x, x.current_tx) for x in self.by_name.values()
                      if x.local_tx.values()]
            # collect route transactions
            for table in self.routes.tables.keys():
                txlist.extend([(x, x.current_tx) for x in
                               self.routes.tables[table]
                               if x.local_tx.values()])
            txlist = sorted(txlist,
                            key=lambda x: x[1]['ipdb_priority'],
                            reverse=True)
            transactions = txlist

        snapshots = []
        removed = []

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if phase == 1:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                target.commit(transaction=tx,
                              commit_phase=phase,
                              commit_mask=phase)
        except Exception:
            if phase == 1:
                self.fallen = transactions
                self.commit(transactions=snapshots, phase=2)
            raise
        else:
            if phase == 1:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if phase == 1:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

    def watchdog(self, action='RTM_NEWLINK', **kwarg):
        return Watchdog(self, action, kwarg)

    def serve_forever(self):
        ###
        # Main monitoring cycle. It gets messages from the
        # default iproute queue and updates objects in the
        # database.
        #
        # Should not be called manually.
        ###
        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except:
                log.error('Restarting IPDB instance after '
                          'error:\n%s', traceback.format_exc())
                if self.restart_on_error:
                    try:
                        self.initdb()
                    except:
                        log.error('Error restarting DB:\n%s',
                                  traceback.format_exc())
                        return
                    # initdb() has started a fresh monitoring thread;
                    # this one must not keep reading the new socket
                    return
                else:
                    raise RuntimeError('Emergency shutdown')

            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except:
                        pass

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in self._event_map:
                        for func in self._event_map[event]:
                            func(msg)

                # run post-callbacks
                # NOTE: post-callbacks are asynchronous
                for (cuid, cb) in tuple(self._post_callbacks.items()):
                    t = threading.Thread(name="IPDB callback %s" % (id(cb)),
                                         target=cb,
                                         args=(self, msg, msg['event']))
                    t.start()
                    if cuid not in self._cb_threads:
                        self._cb_threads[cuid] = set()
                    self._cb_threads[cuid].add(t)

                # occasionally join cb threads
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads.get(cuid, ())):
                        t.join(0)
                        if not t.is_alive():
                            try:
                                self._cb_threads[cuid].remove(t)
                            except KeyError:
                                pass
                            if len(self._cb_threads.get(cuid, ())) == 0:
                                del self._cb_threads[cuid]

    def init_ipaddr_set(self):
        if self.sort_addresses:
            return SortedIPaddrSet()
        else:
            return IPaddrSet()
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 nl_bind_groups=RTNL_GROUPS,
                 ignore_rtables=None, callbacks=None,
                 sort_addresses=False):
        '''
        :param nl: an existing netlink socket; when omitted IPDB
            creates (and owns) its own IPRoute socket
        :param mode: transaction mode, stored as `self.mode`
        :param restart_on_error: restart the DB when the monitoring
            socket fails; defaults to True only for an IPDB-owned socket
        :param nl_async: passed as `async` to the monitoring socket
            bind; None means `config.ipdb_nl_async`
        :param nl_bind_groups: netlink multicast groups to subscribe to
        :param ignore_rtables: an int or a collection of route tables
        :param callbacks: callables, or (callable, mode) tuples, passed
            to `register_callback()`
        :param sort_addresses: make `init_ipaddr_set()` return a
            SortedIPaddrSet
        '''
        self.mode = mode
        self.sort_addresses = sort_addresses
        self._event_map = {}
        # NOTE: `_deferred` must exist before the first attribute *read*,
        # since __getattribute__ consults it on every access
        self._deferred = {}
        self._loaded = set()
        self._mthread = None
        self._nl_own = nl is None
        # None -> use the global default; otherwise honour the caller's
        # flag (an explicit False used to be coerced to True)
        self._nl_async = (config.ipdb_nl_async if nl_async is None
                          else bool(nl_async))
        self.mnl = None
        self.nl = nl
        self.nl_bind_groups = nl_bind_groups
        self._plugins = [interface, route, rule]
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}
        # callback id -> set of threads running that post-callback
        self._cb_threads = {}

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # register callbacks
        #
        # examples::
        #     def cb1(ipdb, msg, event):
        #         print(event, msg)
        #     def cb2(...):
        #         ...
        #
        #     # default mode: post
        #     IPDB(callbacks=[cb1, cb2])
        #     # specify the mode explicitly
        #     IPDB(callbacks=[(cb1, 'pre'), (cb2, 'post')])
        #
        for cba in callbacks or []:
            if not isinstance(cba, (tuple, list, set)):
                cba = (cba, )
            self.register_callback(*cba)

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None

        # init the database
        self.initdb()

        #
        atexit.register(self.release)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def initdb(self):
        '''
        (Re)initialize the database: release the previous state, create
        the sockets, re-register deferred plugins and start a new
        monitoring thread.
        '''
        # common event map, empty by default, so all the
        # events are just ignored
        self.release(complete=False)
        self._stop = False
        # explicitly cleanup object references
        for event in tuple(self._event_map):
            del self._event_map[event]

        # if the command socket is not provided, create it
        if self._nl_own:
            self.nl = IPRoute()
        # setup monitoring socket
        self.mnl = self.nl.clone()
        try:
            # `async` is a reserved word since Python 3.7, so it can not
            # be spelled as a literal keyword argument any more
            self.mnl.bind(groups=self.nl_bind_groups,
                          **{'async': self._nl_async})
        except BaseException:
            self.mnl.close()
            # close the command socket only when IPDB owns it; the
            # former `self._nl_own is None` test was never true (it is
            # a bool), so an owned socket could leak here
            if self._nl_own:
                self.nl.close()
                self.nl = None
            raise

        # explicitly cleanup references
        for key in tuple(self._deferred):
            del self._deferred[key]

        for module in self._plugins:
            if (module.groups & self.nl_bind_groups) != module.groups:
                continue
            for plugin in module.spec:
                self._deferred[plugin['name']] = module.spec
                if plugin['name'] in self._loaded:
                    delattr(self, plugin['name'])
                    self._loaded.remove(plugin['name'])

        # start the monitoring thread
        self._mthread = threading.Thread(name="IPDB event loop",
                                         target=self.serve_forever)
        self._mthread.daemon = True
        self._mthread.start()

    def __getattribute__(self, name):
        # lazily instantiate deferred plugins on first access
        deferred = super(IPDB, self).__getattribute__('_deferred')
        if name in deferred:
            register = []
            spec = deferred[name]
            for plugin in spec:
                obj = plugin['class'](self, **plugin['kwarg'])
                setattr(self, plugin['name'], obj)
                register.append(obj)
                self._loaded.add(plugin['name'])
                del deferred[plugin['name']]
            for obj in register:
                if hasattr(obj, '_register'):
                    obj._register()
                if hasattr(obj, '_event_map'):
                    for event in obj._event_map:
                        if event not in self._event_map:
                            self._event_map[event] = []
                        self._event_map[event].append(obj._event_map[event])
        return super(IPDB, self).__getattribute__(name)

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        ...

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        ...

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        ...

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post")

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]

        Returns the callback id to be used with `unregister_callback()`;
        raises KeyError on an unknown mode (previously the callback was
        silently never registered).
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        else:
            raise KeyError('Unknown callback mode')
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove a callback and wait (up to 3s per thread) for its
        running threads; returns the threads still tracked for it.
        '''
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        safe = cbchain[cuid]
        with safe.lock:
            cbchain.pop(cuid)
        for t in tuple(self._cb_threads.get(cuid, ())):
            t.join(3)
        ret = self._cb_threads.get(cuid, ())
        return ret

    def release(self, complete=True):
        '''
        Shutdown IPDB instance and sync the state.

        Since IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exit the script,
        it is required to properly shutdown IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there is
        enough time to sync the state. But for the scripts the
        `release()` call is required.

        :param complete: also close a user-provided command socket
        '''
        with self._shutdown_lock:
            if self._stop:
                return
            self._stop = True
            if self.mnl is not None:
                # terminate the main loop
                for _ in range(3):
                    try:
                        msg = ifinfmsg()
                        msg['index'] = 1
                        msg.reset()
                        self.mnl.put(msg, RTM_GETLINK)
                    except Exception as e:
                        logging.warning("shutdown error: %s", e)
                        # Just give up.
                        # We can not handle this case

            # when called from the monitoring thread itself (restart
            # path in serve_forever), joining it would raise RuntimeError
            if self._mthread is not None and \
                    self._mthread is not threading.current_thread():
                self._mthread.join()

            if self.mnl is not None:
                self.mnl.close()
                self.mnl = None
                if (complete or self._nl_own) and self.nl is not None:
                    self.nl.close()
                    self.nl = None

            with self.exclusive:
                # collect all the callbacks
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads[cuid]):
                        t.join()

                # flush all the objects
                def flush(idx):
                    for key in tuple(idx.keys()):
                        try:
                            del idx[key]
                        except KeyError:
                            pass

                idx_list = []

                if 'interfaces' in self._loaded:
                    # snapshot: _detach() may drop index entries
                    for (key, dev) in tuple(self.by_name.items()):
                        try:
                            # FIXME
                            self.interfaces._detach(key,
                                                    dev['index'],
                                                    dev.nlmsg)
                        except KeyError:
                            pass
                    idx_list.append(self.ipaddr)
                    idx_list.append(self.neighbours)

                if 'routes' in self._loaded:
                    idx_list.extend([self.routes.tables[x]
                                     for x in self.routes.tables.keys()])

                if 'rules' in self._loaded:
                    idx_list.append(self.rules)

                for idx in idx_list:
                    flush(idx)

    def create(self, kind, ifname, reuse=False, **kwarg):
        return self.interfaces.add(kind, ifname, reuse, **kwarg)

    def commit(self, transactions=None, phase=1):
        '''
        Commit transactions (highest `ipdb_priority` first when
        collected automatically). On a phase 1 failure, the snapshots
        taken before each commit are replayed as phase 2 (rollback) and
        the exception is re-raised.
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            # collect interface transactions
            txlist = [(x, x.current_tx) for x in self.by_name.values()
                      if x.local_tx.values()]
            # collect route transactions
            for table in self.routes.tables.keys():
                txlist.extend([(x, x.current_tx) for x in
                               self.routes.tables[table]
                               if x.local_tx.values()])
            txlist = sorted(txlist,
                            key=lambda x: x[1]['ipdb_priority'],
                            reverse=True)
            transactions = txlist

        snapshots = []
        removed = []

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if phase == 1:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                target.commit(transaction=tx,
                              commit_phase=phase,
                              commit_mask=phase)
        except Exception:
            if phase == 1:
                self.fallen = transactions
                self.commit(transactions=snapshots, phase=2)
            raise
        else:
            if phase == 1:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if phase == 1:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

    def watchdog(self, action='RTM_NEWLINK', **kwarg):
        return Watchdog(self, action, kwarg)

    def serve_forever(self):
        ###
        # Main monitoring cycle. It gets messages from the
        # default iproute queue and updates objects in the
        # database.
        #
        # Should not be called manually.
        ###
        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except:
                log.error('Restarting IPDB instance after '
                          'error:\n%s', traceback.format_exc())
                if self.restart_on_error:
                    try:
                        self.initdb()
                    except:
                        log.error('Error restarting DB:\n%s',
                                  traceback.format_exc())
                        return
                    # initdb() has started a fresh monitoring thread;
                    # this one must not keep reading the new socket
                    return
                else:
                    raise RuntimeError('Emergency shutdown')

            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except:
                        pass

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in self._event_map:
                        for func in self._event_map[event]:
                            func(msg)

                # run post-callbacks
                # NOTE: post-callbacks are asynchronous
                for (cuid, cb) in tuple(self._post_callbacks.items()):
                    t = threading.Thread(name="IPDB callback %s" % (id(cb)),
                                         target=cb,
                                         args=(self, msg, msg['event']))
                    t.start()
                    if cuid not in self._cb_threads:
                        self._cb_threads[cuid] = set()
                    self._cb_threads[cuid].add(t)

                # occasionally join cb threads
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads.get(cuid, ())):
                        t.join(0)
                        if not t.is_alive():
                            try:
                                self._cb_threads[cuid].remove(t)
                            except KeyError:
                                pass
                            if len(self._cb_threads.get(cuid, ())) == 0:
                                del self._cb_threads[cuid]

    def init_ipaddr_set(self):
        if self.sort_addresses:
            return SortedIPaddrSet()
        else:
            return IPaddrSet()