def __init__(self, ipdb=None, mode=None, parent=None, uid=None):
    """Set up an empty transactional object.

    :param ipdb: owning database; when given, its ``nl`` socket is
        reused and its ``mode`` is the fallback transaction mode
    :param mode: explicit transaction mode; when falsy, it is taken
        from ``parent``, then from ``ipdb``, then ``'implicit'``
    :param parent: parent transactional object, if any
    :param uid: explicit object id; a fresh ``uuid32()`` when falsy
    """
    # back-references to the database and its netlink socket
    self.nl = None if ipdb is None else ipdb.nl
    self.ipdb = ipdb

    # the transaction mode: the first available source wins
    self._parent = parent
    if mode:
        self._mode = mode
    elif parent is not None:
        self._mode = parent._mode
    elif ipdb is not None:
        self._mode = ipdb.mode
    else:
        self._mode = 'implicit'

    # identity and bookkeeping
    self.nlmsg = None
    self.uid = uid or uuid32()
    self.last_error = None
    self._commit_hooks = []
    self._sids = []
    self._ts = threading.local()
    self._snapshots = {}
    self.global_tx = {}
    self._targets = {}
    self._local_targets = {}

    # one re-entrant lock guards writes and backs the direct state
    self._write_lock = threading.RLock()
    self._direct_state = State(self._write_lock)
    self._linked_sets = self._linked_sets or set()

    # every declared field starts as None, bypassing transactions
    for field in self._fields:
        TransactionalBase.__setitem__(self, field, None)
def __setitem__(self, direct, key, value):
    """Assign ``value`` to ``key``.

    When ``direct`` is false, the assignment goes to the active
    transaction (which must have been started before the call) and a
    target event is registered for any non-None value.

    When ``direct`` is true, the value is stored on the object itself,
    and every local or nested target that now matches is signalled.
    """
    if not direct:
        tx = self.current_tx
        tx[key] = value
        if value is not None:
            tx._targets[key] = threading.Event()
        return

    def matches(left, right):
        # per-field comparator, plain equality by default
        compare = self._fields_cmp.get(key, lambda x, y: x == y)
        return compare(left, right)

    # store the item on the object itself
    TransactionalBase.__setitem__(self, key, value)

    # signal the local target, if the expected value is reached
    with self._write_lock:
        if key in self._local_targets:
            if matches(value, self._local_targets[key].value):
                self._local_targets[key].set()

    # cascade the update to the targets of nested transactions
    for nested in tuple(self.global_tx.values()):
        if (key in nested._targets) and (key in nested):
            if matches(value, nested[key]):
                nested._targets[key].set()
def __delitem__(self, direct, key):
    """Delete ``key``.

    The key is first assigned ``None`` through ``self[key]`` so that
    targets get set up; then it is removed either from the object
    itself (``direct`` true) or from the active transaction, if the
    transaction contains it.
    """
    # set the targets first
    self[key] = None

    # now the actual removal
    if direct:
        TransactionalBase.__delitem__(self, key)
        return

    tx = self.current_tx
    if key in tx:
        del tx[key]
def initdb(self, nl=None):
    """(Re)create the netlink sockets and load the initial state.

    Creates the resolvers (``interfaces``, ``routes``, ``by_name``,
    ``by_index``) and the caches (``ipaddr``, ``neighbours``), binds
    the monitoring socket and then loads links, VLAN/bridge info,
    addresses, neighbours and routes.

    :param nl: netlink object to use; a new ``IPRoute()`` when falsy
    :raises Exception: any error raised while binding or loading is
        logged, both sockets are closed (best effort), and the error
        is re-raised
    """
    self.nl = nl or IPRoute()
    self.mnl = self.nl.clone()

    # resolvers
    self.interfaces = TransactionalBase()
    self.routes = RoutingTableSet(ipdb=self,
                                  ignore_rtables=self._ignore_rtables)
    self.by_name = View(src=self.interfaces,
                        constraint=lambda k, v: isinstance(k, basestring))
    self.by_index = View(src=self.interfaces,
                         constraint=lambda k, v: isinstance(k, int))

    # caches
    self.ipaddr = {}
    self.neighbours = {}

    try:
        # `async` is a reserved word since Python 3.7, so spelling
        # it as a literal keyword argument is a SyntaxError there;
        # dict unpacking passes exactly the same keyword
        self.mnl.bind(**{'async': self._nl_async})

        # load information: all the links first, slaves afterwards,
        # so that masters are already known
        links = self.nl.get_links()
        for link in links:
            self._interface_add(link, skip_slaves=True)
        for link in links:
            self.update_slaves(link)

        # bridge info
        for link in self.nl.get_vlans():
            self._interface_add(link)

        for msg in self.nl.get_addr():
            self._addr_add(msg)
        for msg in self.nl.get_neighbours():
            self._neigh_add(msg)
        for family in (AF_INET, AF_INET6, AF_MPLS):
            for msg in self.nl.get_routes(family=family):
                self._route_add(msg)
    except Exception as e:
        logging.error('initdb error: %s', e)
        logging.error(traceback.format_exc())
        # close the sockets independently: a failure to close one
        # must not leave the other one open
        for sock in (self.nl, self.mnl):
            try:
                sock.close()
            except Exception:
                pass
        # explicit re-raise of the original error, so that an error
        # swallowed by the cleanup above can never be raised instead
        raise e
class IPDB(object):
    '''
    The class that maintains information about network setup
    of the host. Monitoring netlink events allows it to react
    immediately. It uses no polling.
    '''

    def __init__(self, nl=None, mode='implicit',
                 restart_on_error=None, nl_async=None,
                 debug=False, ignore_rtables=None):
        '''
        Load the initial state and start the monitoring thread.

        :param nl: netlink object to use; created by `initdb()`
            when not given
        :param mode: default transaction mode for created objects
        :param restart_on_error: whether `serve_forever()` should
            call `initdb()` again after a socket error; by default
            it is enabled only when `nl` is not provided
        :param nl_async: value passed as `async` to the monitoring
            socket `bind()`; `None` means `config.ipdb_nl_async`
        :param debug: stored as `self.debug`
        :param ignore_rtables: routing table id, or a list/tuple/set
            of ids, passed to `RoutingTableSet`
        '''
        self.mode = mode
        self.debug = debug
        if isinstance(ignore_rtables, int):
            self._ignore_rtables = [ignore_rtables, ]
        elif isinstance(ignore_rtables, (list, tuple, set)):
            self._ignore_rtables = ignore_rtables
        else:
            self._ignore_rtables = []
        self.iclass = Interface
        # honour an explicit value; previously anything that was
        # not None -- including False -- was turned into True
        self._nl_async = config.ipdb_nl_async \
            if nl_async is None else nl_async
        self._stop = False
        # see also 'register_callback'
        self._post_callbacks = {}
        self._pre_callbacks = {}
        self._cb_threads = {}

        # locks and events
        self.exclusive = threading.RLock()
        self._shutdown_lock = threading.Lock()

        # load information
        self.restart_on_error = restart_on_error if \
            restart_on_error is not None else nl is None
        self.initdb(nl)

        # start monitoring thread
        self._mthread = threading.Thread(target=self.serve_forever)
        self._mthread.daemon = True
        self._mthread.start()
        #
        atexit.register(self.release)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def initdb(self, nl=None):
        '''
        (Re)create the netlink sockets and load the initial state:
        links, VLAN/bridge info, addresses, neighbours and routes.

        Any error raised while binding or loading is logged, both
        sockets are closed (best effort), and the error is re-raised.
        '''
        self.nl = nl or IPRoute()
        self.mnl = self.nl.clone()

        # resolvers
        self.interfaces = TransactionalBase()
        self.routes = RoutingTableSet(ipdb=self,
                                      ignore_rtables=self._ignore_rtables)
        self.by_name = View(src=self.interfaces,
                            constraint=lambda k, v: isinstance(k, basestring))
        self.by_index = View(src=self.interfaces,
                             constraint=lambda k, v: isinstance(k, int))

        # caches
        self.ipaddr = {}
        self.neighbours = {}

        try:
            # `async` is a reserved word since Python 3.7, so
            # spelling it as a literal keyword argument is a
            # SyntaxError there; dict unpacking passes exactly
            # the same keyword
            self.mnl.bind(**{'async': self._nl_async})

            # load information: all the links first, slaves
            # afterwards, so that masters are already known
            links = self.nl.get_links()
            for link in links:
                self._interface_add(link, skip_slaves=True)
            for link in links:
                self.update_slaves(link)

            # bridge info
            for link in self.nl.get_vlans():
                self._interface_add(link)

            for msg in self.nl.get_addr():
                self._addr_add(msg)
            for msg in self.nl.get_neighbours():
                self._neigh_add(msg)
            for family in (AF_INET, AF_INET6, AF_MPLS):
                for msg in self.nl.get_routes(family=family):
                    self._route_add(msg)
        except Exception as e:
            logging.error('initdb error: %s', e)
            logging.error(traceback.format_exc())
            # close the sockets independently: a failure to close
            # one must not leave the other one open
            for sock in (self.nl, self.mnl):
                try:
                    sock.close()
                except Exception:
                    pass
            # explicit re-raise of the original error, so that an
            # error swallowed by the cleanup can never be raised
            # instead
            raise e

    def register_callback(self, callback, mode='post'):
        '''
        IPDB callbacks are routines executed on a RT netlink
        message arrival. There are two types of callbacks:
        "post" and "pre" callbacks.

        "Post" callbacks are executed after the message is
        processed by IPDB and all corresponding objects are
        created or deleted. Using ipdb reference in "post"
        callbacks you will access the most up-to-date state
        of the IP database.

        "Post" callbacks are executed asynchronously in
        separate threads. These threads can work as long
        as you want them to. Callback threads are joined
        occasionally, so for a short time there can exist
        stopped threads.

        "Pre" callbacks are synchronous routines, executed
        before the message gets processed by IPDB. It gives
        you the way to patch arriving messages, but also
        places a restriction: until the callback exits, the
        main event IPDB loop is blocked.

        Normally, only "post" callbacks are required. But in
        some specific cases "pre" also can be useful.

        The routine, `register_callback()`, takes two arguments:
            - callback function
            - mode (optional, default="post")

        The callback should be a routine, that accepts three
        arguments::

            cb(ipdb, msg, action)

        Arguments are:

            - **ipdb** is a reference to IPDB instance, that invokes
                the callback.
            - **msg** is a message arrived
            - **action** is just a msg['event'] field

        E.g., to work on a new interface, you should catch
        action == 'RTM_NEWLINK' and with the interface index
        (arrived in msg['index']) get it from IPDB::

            index = msg['index']
            interface = ipdb.interfaces[index]

        Returns the id of the registered callback, to be used
        with `unregister_callback()`. A mode other than "post"
        or "pre" registers nothing, but the id is still returned.
        '''
        lock = threading.Lock()

        def safe(*argv, **kwarg):
            # serialize the invocations of one callback
            with lock:
                callback(*argv, **kwarg)

        safe.hook = callback
        safe.lock = lock
        safe.uuid = uuid32()

        if mode == 'post':
            self._post_callbacks[safe.uuid] = safe
        elif mode == 'pre':
            self._pre_callbacks[safe.uuid] = safe
        return safe.uuid

    def unregister_callback(self, cuid, mode='post'):
        '''
        Remove the callback `cuid` from the `mode` chain and wait
        up to 3 seconds for each of its threads.

        Returns the collection of threads still registered for the
        callback (an empty tuple, if there are none). Raises
        `KeyError` for an unknown mode or an unknown `cuid`.
        '''
        if mode == 'post':
            cbchain = self._post_callbacks
        elif mode == 'pre':
            cbchain = self._pre_callbacks
        else:
            raise KeyError('Unknown callback mode')
        safe = cbchain[cuid]
        # wait for a running invocation before dropping the callback
        with safe.lock:
            cbchain.pop(cuid)
        for t in tuple(self._cb_threads.get(cuid, ())):
            t.join(3)
        ret = self._cb_threads.get(cuid, ())
        return ret

    def release(self):
        '''
        Shutdown IPDB instance and sync the state. Since
        IPDB is asynchronous, some operations continue in the
        background, e.g. callbacks. So, prior to exit the
        script, it is required to properly shutdown IPDB.

        The shutdown sequence is not forced in an interactive
        python session, since it is easier for users and there
        is enough time to sync the state. But for the scripts
        the `release()` call is required.
        '''
        with self.exclusive:
            with self._shutdown_lock:
                # release() is idempotent
                if self._stop:
                    return
                self._stop = True

                # collect all the callbacks
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads[cuid]):
                        t.join()

                # terminate the main loop: feed it a message, so
                # that it wakes up and notices the _stop flag
                try:
                    for t in range(3):
                        self.mnl.put({'index': 1}, RTM_GETLINK)
                        self._mthread.join(t)
                        if not self._mthread.is_alive():
                            break
                except Exception:
                    # Just give up.
                    # We can not handle this case
                    pass

                self.nl.close()
                self.nl = None
                self.mnl.close()
                self.mnl = None

                # flush all the objects; iterate over a snapshot,
                # since detach() removes items from the mapping
                for (key, dev) in tuple(self.by_name.items()):
                    try:
                        self.detach(key, dev['index'], dev.nlmsg)
                    except KeyError:
                        pass

                def flush(idx):
                    for key in tuple(idx.keys()):
                        try:
                            del idx[key]
                        except KeyError:
                            pass

                idx_list = [self.routes.tables[x] for x
                            in self.routes.tables.keys()]
                idx_list.append(self.ipaddr)
                idx_list.append(self.neighbours)
                for idx in idx_list:
                    flush(idx)

    def create(self, kind, ifname, reuse=False, **kwarg):
        '''
        Create (or reuse) an interface object and begin a
        transaction on it.

        An existing interface is reused if it is in the 'shadow'
        scope or if `reuse` is true; otherwise `CreateException`
        is raised. Returns the interface object.
        '''
        with self.exclusive:
            # check for existing interface
            if ifname in self.interfaces:
                if (self.interfaces[ifname]['ipdb_scope'] == 'shadow') \
                        or reuse:
                    device = self.interfaces[ifname]
                    kwarg['kind'] = kind
                    device.load_dict(kwarg)
                    if self.interfaces[ifname]['ipdb_scope'] == 'shadow':
                        with device._direct_state:
                            device['ipdb_scope'] = 'create'
                else:
                    raise CreateException("interface %s exists" %
                                          ifname)
            else:
                device = \
                    self.interfaces[ifname] = \
                    self.iclass(ipdb=self, mode='snapshot')
                device.update(kwarg)
                # interface objects are referenced by the index
                if isinstance(kwarg.get('link', None), Interface):
                    device['link'] = kwarg['link']['index']
                if isinstance(kwarg.get('vxlan_link', None), Interface):
                    device['vxlan_link'] = kwarg['vxlan_link']['index']
                device['kind'] = kind
                device['index'] = kwarg.get('index', 0)
                device['ifname'] = ifname
                device['ipdb_scope'] = 'create'
                device._mode = self.mode
            device.begin()
            return device

    def commit(self, transactions=None, rollback=False):
        '''
        Commit a list of `(target, transaction)` pairs.

        When `transactions` is None, the list is collected from the
        interfaces and routes having local transactions, ordered by
        'ipdb_priority', highest first. If a commit fails and this
        is not already a rollback, the snapshots taken so far are
        committed back with `rollback=True`, and the error is
        re-raised.
        '''
        # what to commit: either from transactions argument, or from
        # started transactions on existing objects
        if transactions is None:
            # collect interface transactions
            txlist = [(x, x.current_tx) for x
                      in self.by_name.values() if x.local_tx.values()]
            # collect route transactions
            for table in self.routes.tables.keys():
                txlist.extend([(x, x.current_tx) for x
                               in self.routes.tables[table]
                               if x.local_tx.values()])
            txlist = sorted(txlist,
                            key=lambda x: x[1]['ipdb_priority'],
                            reverse=True)
            transactions = txlist

        snapshots = []
        removed = []

        try:
            for (target, tx) in transactions:
                if target['ipdb_scope'] == 'detached':
                    continue
                if tx['ipdb_scope'] == 'remove':
                    tx['ipdb_scope'] = 'shadow'
                    removed.append((target, tx))
                if not rollback:
                    s = (target, target.pick(detached=True))
                    snapshots.append(s)
                target.commit(transaction=tx, rollback=rollback)
        except Exception:
            if not rollback:
                self.fallen = transactions
                self.commit(transactions=snapshots, rollback=True)
            raise
        else:
            if not rollback:
                for (target, tx) in removed:
                    target['ipdb_scope'] = 'detached'
                    target.detach()
        finally:
            if not rollback:
                for (target, tx) in transactions:
                    target.drop(tx.uid)

    def _interface_del(self, msg):
        '''
        Handle RTM_DELLINK: mark the routes via the interface for
        gc and detach the interface, unless it is frozen, locked
        or shadowed.
        '''
        target = self.interfaces.get(msg['index'])
        if target is None:
            return
        for record in self.routes.filter({'oif': msg['index']}):
            with record['route']._direct_state:
                record['route']['ipdb_scope'] = 'gc'
        for record in self.routes.filter({'iif': msg['index']}):
            with record['route']._direct_state:
                record['route']['ipdb_scope'] = 'gc'
        target.nlmsg = msg
        # check for freezed devices
        if getattr(target, '_freeze', None):
            with target._direct_state:
                target['ipdb_scope'] = 'shadow'
            return
        # check for locked devices
        if target.get('ipdb_scope') in ('locked', 'shadow'):
            return
        self.detach(None, msg['index'], msg)

    def _interface_add(self, msg, skip_slaves=False):
        '''
        Handle RTM_NEWLINK: create or update the interface record,
        keeping the index and the name keys in sync.
        '''
        # check, if a record exists
        index = msg.get('index', None)
        ifname = msg.get_attr('IFLA_IFNAME', None)
        # scenario #1: no matches for both: new interface
        # scenario #2: ifname exists, index doesn't: index changed
        # scenario #3: index exists, ifname doesn't: name changed
        # scenario #4: both exist: assume simple update and
        # an optional name change
        if ((index not in self.interfaces) and
                (ifname not in self.interfaces)):
            # scenario #1, new interface
            device = \
                self.interfaces[index] = \
                self.interfaces[ifname] = self.iclass(ipdb=self)
        elif ((index not in self.interfaces) and
                (ifname in self.interfaces)):
            # scenario #2, index change
            old_index = self.interfaces[ifname]['index']
            device = \
                self.interfaces[index] = \
                self.interfaces[ifname]
            if old_index in self.interfaces:
                del self.interfaces[old_index]
            # move the caches to the new index
            if old_index in self.ipaddr:
                self.ipaddr[index] = self.ipaddr[old_index]
                del self.ipaddr[old_index]
            if old_index in self.neighbours:
                self.neighbours[index] = self.neighbours[old_index]
                del self.neighbours[old_index]
        else:
            # scenario #3, interface rename
            # scenario #4, assume rename
            old_name = self.interfaces[index]['ifname']
            if old_name != ifname:
                # unlink old name
                del self.interfaces[old_name]
            device = \
                self.interfaces[ifname] = \
                self.interfaces[index]

        if index not in self.ipaddr:
            # for interfaces, created by IPDB
            self.ipaddr[index] = IPaddrSet()

        if index not in self.neighbours:
            self.neighbours[index] = LinkedSet()

        device.load_netlink(msg)

        if not skip_slaves:
            self.update_slaves(msg)

    def detach(self, name, idx, msg=None):
        '''
        Remove the interface from the database and from the caches
        and mark it as 'detached'.

        The interface is looked up by `idx`, or by `name` when
        `idx` is None or less than 1. With an RTM_DELLINK `msg`
        whose 'change' field is not 0xffffffff, only the slaves
        are updated and nothing is detached.
        '''
        with self.exclusive:
            if msg is not None:
                try:
                    self.update_slaves(msg)
                except KeyError:
                    pass
                if msg['event'] == 'RTM_DELLINK' and \
                        msg['change'] != 0xffffffff:
                    return
            if idx is None or idx < 1:
                target = self.interfaces[name]
                idx = target['index']
            else:
                target = self.interfaces[idx]
                name = target['ifname']
            self.interfaces.pop(name, None)
            self.interfaces.pop(idx, None)
            self.ipaddr.pop(idx, None)
            self.neighbours.pop(idx, None)
            with target._direct_state:
                target['ipdb_scope'] = 'detached'

    def watchdog(self, action='RTM_NEWLINK', **kwarg):
        '''Return a `Watchdog` for `action` and the `kwarg` dict.'''
        return Watchdog(self, action, kwarg)

    def _route_add(self, msg):
        '''Handle RTM_NEWROUTE: pass the message to the route set.'''
        self.routes.load_netlink(msg)

    def _route_del(self, msg):
        '''Handle RTM_DELROUTE: pass the message to the route set.'''
        self.routes.load_netlink(msg)

    def update_slaves(self, msg):
        '''
        Sync the 'ports' lists of the masters with the IFLA_MASTER
        attribute of the link message.
        '''
        # Update slaves list -- only after update IPDB!

        index = msg['index']
        master_index = msg.get_attr('IFLA_MASTER')
        if index == master_index:
            # one special case: links() call with AF_BRIDGE
            # returns IFLA_MASTER == index
            return
        master = self.interfaces.get(master_index, None)

        # there IS a master for the interface
        if master is not None:
            if msg['event'] == 'RTM_NEWLINK':
                # TODO tags: ipdb
                # The code serves one particular case, when
                # an enslaved interface is set to belong to
                # another master. In this case there will be
                # no 'RTM_DELLINK', only 'RTM_NEWLINK', and
                # we can end up in a broken state, when two
                # masters refers to the same slave
                for device in self.by_index:
                    if index in self.interfaces[device]['ports']:
                        try:
                            with self.interfaces[device]._direct_state:
                                self.interfaces[device].del_port(index)
                        except KeyError:
                            pass
                with master._direct_state:
                    master.add_port(index)
            elif msg['event'] == 'RTM_DELLINK':
                if index in master['ports']:
                    with master._direct_state:
                        master.del_port(index)
        # there is NO masters for the interface, clean them if any
        else:
            device = self.interfaces[msg['index']]
            # clean vlan list from the port
            for vlan in tuple(device['vlans']):
                with device._direct_state:
                    device.del_vlan(vlan)
            # clean device from ports
            for master in self.by_index:
                if index in self.interfaces[master]['ports']:
                    try:
                        with self.interfaces[master]._direct_state:
                            self.interfaces[master].del_port(index)
                    except KeyError:
                        pass
            master = device.if_master
            if master is not None:
                if 'master' in device:
                    with device._direct_state:
                        device['master'] = None
                if (master in self.interfaces) and \
                        (msg['index'] in self.interfaces[master]['ports']):
                    try:
                        with self.interfaces[master]._direct_state:
                            self.interfaces[master].del_port(index)
                    except KeyError:
                        pass

    def _addr_add(self, msg):
        '''
        Handle RTM_NEWADDR: add the address to the cache of the
        interface. Families other than AF_INET and AF_INET6 are
        ignored, as well as any error on the cache update.
        '''
        if msg['family'] == AF_INET:
            addr = msg.get_attr('IFA_LOCAL')
        elif msg['family'] == AF_INET6:
            addr = msg.get_attr('IFA_ADDRESS')
        else:
            return
        raw = {'local': msg.get_attr('IFA_LOCAL'),
               'broadcast': msg.get_attr('IFA_BROADCAST'),
               'address': msg.get_attr('IFA_ADDRESS'),
               'flags': msg.get_attr('IFA_FLAGS'),
               'prefixlen': msg['prefixlen']}
        try:
            self.ipaddr[msg['index']].add(key=(addr, raw['prefixlen']),
                                          raw=raw)
        except Exception:
            # best effort, e.g. the interface is not known
            pass

    def _addr_del(self, msg):
        '''
        Handle RTM_DELADDR: remove the address from the cache of
        the interface; errors are ignored.
        '''
        if msg['family'] == AF_INET:
            addr = msg.get_attr('IFA_LOCAL')
        elif msg['family'] == AF_INET6:
            addr = msg.get_attr('IFA_ADDRESS')
        else:
            return
        try:
            self.ipaddr[msg['index']].remove((addr, msg['prefixlen']))
        except Exception:
            # best effort, e.g. the record is already gone
            pass

    def _neigh_add(self, msg):
        '''
        Handle RTM_NEWNEIGH: add the neighbour to the cache of the
        interface. AF_BRIDGE records are ignored, as well as any
        error on the cache update.
        '''
        if msg['family'] == AF_BRIDGE:
            return

        try:
            # the link-layer address attribute is NDA_LLADDR, the
            # same NDA_ family as NDA_DST
            (self
             .neighbours[msg['ifindex']]
             .add(key=msg.get_attr('NDA_DST'),
                  raw={'lladdr': msg.get_attr('NDA_LLADDR')}))
        except Exception:
            # best effort, e.g. the interface is not known
            pass

    def _neigh_del(self, msg):
        '''
        Handle RTM_DELNEIGH: remove the neighbour from the cache
        of the interface; errors are ignored.
        '''
        if msg['family'] == AF_BRIDGE:
            return
        try:
            (self
             .neighbours[msg['ifindex']]
             .remove(msg.get_attr('NDA_DST')))
        except Exception:
            # best effort, e.g. the record is already gone
            pass

    def serve_forever(self):
        '''
        Main monitoring cycle. It gets messages from the
        default iproute queue and updates objects in the
        database.

        Should not be called manually.
        '''
        event_map = {'RTM_NEWLINK': self._interface_add,
                     'RTM_DELLINK': self._interface_del,
                     'RTM_NEWADDR': self._addr_add,
                     'RTM_DELADDR': self._addr_del,
                     'RTM_NEWNEIGH': self._neigh_add,
                     'RTM_DELNEIGH': self._neigh_del,
                     'RTM_NEWROUTE': self._route_add,
                     'RTM_DELROUTE': self._route_del}
        while not self._stop:
            try:
                messages = self.mnl.get()
                ##
                # Check it again
                #
                # NOTE: one should not run callbacks or
                # anything like that after setting the
                # _stop flag, since IPDB is not valid
                # anymore
                if self._stop:
                    break
            except:
                # deliberately broad: whatever breaks the socket,
                # either restart the database or shut down
                logging.error('Restarting IPDB instance after '
                              'error:\n%s', traceback.format_exc())
                if self.restart_on_error:
                    try:
                        self.initdb()
                    except:
                        logging.error('Error restarting DB:\n%s',
                                      traceback.format_exc())
                        return
                    continue
                else:
                    raise RuntimeError('Emergency shutdown')
            for msg in messages:
                # Run pre-callbacks
                # NOTE: pre-callbacks are synchronous
                for (cuid, cb) in tuple(self._pre_callbacks.items()):
                    try:
                        cb(self, msg, msg['event'])
                    except:
                        # deliberately broad: a failing user
                        # callback must not stop the main loop
                        pass

                with self.exclusive:
                    event = msg.get('event', None)
                    if event in event_map:
                        event_map[event](msg)

                # run post-callbacks
                # NOTE: post-callbacks are asynchronous
                for (cuid, cb) in tuple(self._post_callbacks.items()):
                    t = threading.Thread(name="callback %s" % (id(cb)),
                                         target=cb,
                                         args=(self, msg, msg['event']))
                    t.start()
                    if cuid not in self._cb_threads:
                        self._cb_threads[cuid] = set()
                    self._cb_threads[cuid].add(t)

                # occasionally join cb threads
                for cuid in tuple(self._cb_threads):
                    for t in tuple(self._cb_threads.get(cuid, ())):
                        t.join(0)
                        if not t.is_alive():
                            try:
                                self._cb_threads[cuid].remove(t)
                            except KeyError:
                                pass
                            if len(self._cb_threads.get(cuid, ())) == 0:
                                del self._cb_threads[cuid]