class TestModule(Module):
    """
    Demo module for ``ZooKeeperClient``.

    Once the ZooKeeper session is created it runs a few batches of requests
    (create / getdata / delete / multi / getchildren2) against test znodes,
    prints the raw results and spawns a subroutine for every watcher returned.
    """
    _default_serverlist = ['tcp://localhost:3181/', 'tcp://localhost:3182/', 'tcp://localhost:3183/']

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.client = ZooKeeperClient(self.apiroutine, self.serverlist)
        self.connections.append(self.client)
        self.apiroutine.main = self.main
        self.routines.append(self.apiroutine)

    def watcher(self):
        "Print every ZooKeeperWatcherEvent received, forever."
        watcher = ZooKeeperWatcherEvent.createMatcher()
        while True:
            yield (watcher,)
            print('WatcherEvent: %r' % (dump(self.apiroutine.event.message),))

    def main(self):
        "Wait for the session, then run the demo request batches."
        def _watch(w):
            # Wait for a single watcher to fire and print what it returned
            for m in w.wait(self.apiroutine):
                yield m
            print('Watcher returns:', dump(self.apiroutine.retvalue))

        def _watchall(watchers):
            # Requests without a watch yield None in the watcher list
            for w in watchers:
                if w is not None:
                    self.apiroutine.subroutine(_watch(w))

        def _run(reqs):
            # Send one batch of requests, print the results and follow the watchers.
            # retvalue[0] is pretty-printed and retvalue[3] is the watcher list.
            for m in self.client.requests(reqs, self.apiroutine):
                yield m
            print(self.apiroutine.retvalue)
            pprint(dump(self.apiroutine.retvalue[0]))
            _watchall(self.apiroutine.retvalue[3])

        self.apiroutine.subroutine(self.watcher(), False, daemon=True)
        up = ZooKeeperSessionStateChanged.createMatcher(ZooKeeperSessionStateChanged.CREATED, self.client)
        yield (up,)
        print('Connection is up: %r' % (self.client.currentserver,))
        for m in _run([zk.create(b'/vlcptest', b'test'),
                       zk.getdata(b'/vlcptest', True)]):
            yield m
        for m in self.apiroutine.waitWithTimeout(0.2):
            yield m
        for m in _run([zk.delete(b'/vlcptest'),
                       zk.getdata(b'/vlcptest', watch=True)]):
            yield m
        # Znode payloads are bytes, like every other request in this demo
        for m in _run([zk.multi(zk.multi_create(b'/vlcptest2', b'test'),
                                zk.multi_create(b'/vlcptest2/subtest', b'test2')),
                       zk.getchildren2(b'/vlcptest2', True)]):
            yield m
        for m in _run([zk.multi(zk.multi_delete(b'/vlcptest2/subtest'),
                                zk.multi_delete(b'/vlcptest2')),
                       zk.getchildren2(b'/vlcptest2', True)]):
            yield m
class TestModule(Module):
    """
    Demo module for the raw ``ZooKeeper`` protocol over a single ``Client``.

    Waits for the connection, performs the ZooKeeper handshake, then runs a
    few batches of requests against test znodes and pretty-prints the results.
    """
    _default_url = 'tcp://localhost/'
    _default_sessiontimeout = 30

    def __init__(self, server):
        Module.__init__(self, server)
        self.protocol = ZooKeeper()
        self.client = Client(self.url, self.protocol, self.scheduler)
        self.connections.append(self.client)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self.main
        self.routines.append(self.apiroutine)

    def watcher(self):
        "Print every ZooKeeperWatcherEvent from this connection, forever."
        watcher = ZooKeeperWatcherEvent.createMatcher(connection=self.client)
        while True:
            yield (watcher,)
            print('WatcherEvent: %r' % (dump(self.apiroutine.event.message),))

    def main(self):
        "Connect, handshake, then run the demo request batches."
        def _run(reqs):
            # Send one batch of requests and pretty-print retvalue[0]
            for m in self.protocol.requests(self.client, reqs, self.apiroutine):
                yield m
            pprint(dump(self.apiroutine.retvalue[0]))

        self.apiroutine.subroutine(self.watcher(), False, daemon=True)
        up = ZooKeeperConnectionStateEvent.createMatcher(ZooKeeperConnectionStateEvent.UP, self.client)
        notconn = ZooKeeperConnectionStateEvent.createMatcher(ZooKeeperConnectionStateEvent.NOTCONNECTED, self.client)
        yield (up, notconn)
        if self.apiroutine.matcher is notconn:
            print('Not connected')
            return
        print('Connection is up: %r' % (self.client,))
        # Handshake. sessiontimeout is multiplied by 1000 for timeOut
        # (presumably seconds -> milliseconds).
        for m in self.protocol.handshake(self.client, zk.ConnectRequest(
                timeOut=int(self.sessiontimeout * 1000),
                # NOTE(review): the original author questioned why an explicit
                # all-zero 16-byte password is required here; kept as-is.
                passwd=b'\x00' * 16,
                ), self.apiroutine, []):
            yield m
        for m in _run([zk.create(b'/vlcptest', b'test'),
                       zk.getdata(b'/vlcptest', True)]):
            yield m
        for m in self.apiroutine.waitWithTimeout(0.2):
            yield m
        for m in _run([zk.delete(b'/vlcptest'),
                       zk.getdata(b'/vlcptest', watch=True)]):
            yield m
        # Znode payloads are bytes, like every other request in this demo
        for m in _run([zk.multi(zk.multi_create(b'/vlcptest2', b'test'),
                                zk.multi_create(b'/vlcptest2/subtest', b'test2')),
                       zk.getchildren2(b'/vlcptest2', True)]):
            yield m
        for m in _run([zk.multi(zk.multi_delete(b'/vlcptest2/subtest'),
                                zk.multi_delete(b'/vlcptest2')),
                       zk.getchildren2(b'/vlcptest2', True)]):
            yield m
class OpenflowManager(Module):
    '''
    Manage Openflow Connections

    Tracks OpenFlow connections per (vhost, datapath ID) and per
    (vhost, endpoint), and computes the table ID layout requested by other
    modules through ``acquiretable`` / ``unacquiretable``.
    '''
    service = True
    _default_vhostbind = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> list of connections (one per auxiliary id)
        self.managed_conns = {}
        # (vhost, endpoint) -> list of connections coming from that endpoint
        self.endpoint_conns = {}
        # Names of modules that asked for tables through acquiretable()
        self.table_modules = set()
        # True while an _acquire_tables subroutine is running
        self._acquiring = False
        # Set when table_modules changed and the layout must be recomputed
        self._acquire_updated = False
        # Last successful layout: vhost -> (full_indices, tables); None until the first one
        self._lastacquire = None
        # Set once the connections already known by openflowserver are loaded
        self._synchronized = False
        self.createAPI(api(self.getconnections, self.apiroutine),
                       api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.acquiretable, self.apiroutine),
                       api(self.unacquiretable, self.apiroutine),
                       api(self.lastacquiredtables))

    def _add_connection(self, conn):
        """
        Register a newly set-up connection.

        An existing connection with the same vhost, datapath ID and auxiliary ID
        is replaced and unregistered from ``endpoint_conns``.

        :return: list of replaced connections (empty, or exactly one element)
        """
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault((vhost, conn.openflow_datapathid), [])
        remove = []
        for i, ci in enumerate(conns):
            if ci.openflow_auxiliaryid == conn.openflow_auxiliaryid:
                remove = [ci]
                ep = _get_endpoint(ci)
                econns = self.endpoint_conns.get((vhost, ep))
                if econns is not None:
                    try:
                        econns.remove(ci)
                    except ValueError:
                        pass
                    if not econns:
                        del self.endpoint_conns[(vhost, ep)]
                # Safe to mutate while enumerating: we stop right after
                del conns[i]
                break
        conns.append(conn)
        ep = _get_endpoint(conn)
        self.endpoint_conns.setdefault((vhost, ep), []).append(conn)
        # Only the main connection (auxiliary id 0) is initialized, and only
        # after a table layout has been acquired
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return remove

    def _initialize_connection(self, conn):
        """
        Reset a datapath and install the default inter-table flows.

        Deletes all flows (and all groups when this OpenFlow version has
        ``ofp_group_mod``). If a layout is acquired for the connection's vhost
        and the version supports goto-table, it adds a priority-0 flow in each
        table of a path that jumps to the next table of the same path. Finally
        it sends ``FlowInitialize``.
        """
        ofdef = conn.openflowdef
        flow_mod = ofdef.ofp_flow_mod(buffer_id=ofdef.OFP_NO_BUFFER,
                                      out_port=ofdef.OFPP_ANY,
                                      command=ofdef.OFPFC_DELETE)
        # These fields only exist in some OpenFlow versions
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            cmds.append(ofdef.ofp_group_mod(command=ofdef.OFPGC_DELETE,
                                            group_id=ofdef.OFPG_ALL))
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                # pathtable: pathname -> ((table_name, table_id), ...) ordered by table id
                cmds = [ofdef.ofp_flow_mod(table_id=t[i][1],
                                           command=ofdef.OFPFC_ADD,
                                           priority=0,
                                           buffer_id=ofdef.OFP_NO_BUFFER,
                                           out_port=ofdef.OFPP_ANY,
                                           out_group=ofdef.OFPG_ANY,
                                           match=ofdef.ofp_match_oxm(),
                                           instructions=[ofdef.ofp_instruction_goto_table(
                                               table_id=t[i + 1][1])])
                        for t in pathtable.values()
                        for i in range(0, len(t) - 1)]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        for m in self.apiroutine.waitForSend(
                FlowInitialize(conn, conn.openflow_datapathid, conn.protocol.vhost)):
            yield m

    def _compute_table_layout(self, requests, vh):
        """
        Compute the table ID layout for one vhost.

        :param requests: list of ``(table_list, vhosts)`` pairs, where ``vhosts``
                         is None (applies to every vhost) or a collection of
                         vhosts, and ``table_list`` holds
                         ``(name, (ancester, ...), pathname)`` entries
        :param vh: the vhost to compute the layout for
        :return: ``(full_indices, tables)``. ``full_indices`` is a list of
                 ``(name, table_id)`` in ID order. ``tables`` maps
                 pathname -> tuple of ``(name, table_id)``.
        :raises ValueError: on a table placed in two paths, malformed requests,
                            a dependency cycle, or more than 255 tables.
                            The error is logged before raising.
        """
        # name -> (set of not-yet-placed ancesters, set of descendants)
        graph = {}
        table_path = {}
        try:
            for table_list, req_vhosts in requests:
                if req_vhosts is None or vh in req_vhosts:
                    for name, ancesters, pathname in table_list:
                        if name in table_path:
                            if table_path[name] != pathname:
                                raise ValueError("table conflict detected: %r can not be in two path: %r and %r"
                                                 % (name, table_path[name], pathname))
                        else:
                            table_path[name] = pathname
                        if name not in graph:
                            graph[name] = (set(ancesters), set())
                        else:
                            graph[name][0].update(ancesters)
                        for anc in ancesters:
                            graph.setdefault(anc, (set(), set()))[1].add(name)
        except ValueError as exc:
            self._logger.error(str(exc))
            raise
        sequences = []

        def dfs_sort(current):
            # Place a table, then place each descendant once all its ancesters are placed
            sequences.append(current)
            for d in graph[current][1]:
                anc = graph[d][0]
                anc.remove(current)
                if not anc:
                    dfs_sort(d)

        # Tables without ancesters, ordered by (pathname, name)
        nopre_tables = sorted([k for k, v in graph.items() if not v[0]],
                              key=lambda x: (table_path.get(x, ''), x))
        for t in nopre_tables:
            dfs_sort(t)
        if len(sequences) < len(graph):
            # Tables never placed are part of (or depend on) a cycle
            rest_tables = set(graph.keys()).difference(sequences)
            self._logger.error("Circle detected in table acquiring, following tables are related: %r, vhost = %r",
                               sorted(rest_tables), vh)
            self._logger.error("Circle dependencies are: %s",
                               ", ".join(repr(tuple(graph[t][0])) + "=>" + t for t in rest_tables))
            raise ValueError("Circle detected in table acquiring, following tables are related: %r, vhost = %r"
                             % (sorted(rest_tables), vh))
        if len(sequences) > 255:
            self._logger.error("Table limit exceeded: %d tables (only 255 allowed), vhost = %r",
                               len(sequences), vh)
            raise ValueError("Table limit exceeded: %d tables (only 255 allowed), vhost = %r"
                             % (len(sequences), vh))
        full_indices = list(zip(sequences, itertools.count()))

        def path_of(entry):
            return table_path.get(entry[0], '')

        # sorted() is stable, so each path keeps its tables in ID order
        tables = dict((k, tuple(g))
                      for k, g in itertools.groupby(sorted(full_indices, key=path_of), path_of))
        return (full_indices, tables)

    def _acquire_tables(self):
        """
        Recompute the table layout until no more module changes are pending,
        then publish it with ``TableAcquireUpdate``.
        """
        try:
            while self._acquire_updated:
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(),)
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll(
                            (callAPI(self.apiroutine, module, 'gettablerequest', {})
                             for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests if vhs is not None for vh in vhs)
                    vhost_result = {}
                    for vh in vhosts:
                        try:
                            vhost_result[vh] = self._compute_table_layout(requests, vh)
                        except ValueError as exc:
                            # Already logged by _compute_table_layout
                            exception = exc
                            break
        finally:
            self._acquiring = False
        if exception is not None:
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(exception=exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                self._lastacquire = result
                self._reinitall()
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(result=result)):
                yield m

    def load(self, container):
        self.scheduler.queue.addSubQueue(1, TableAcquireDelayEvent.createMatcher(),
                                         'ofpmanager_tableacquiredelay')
        for m in container.waitForSend(TableAcquireUpdate(result=None)):
            yield m
        for m in Module.load(self, container):
            yield m

    def unload(self, container, force=False):
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(syscall_removequeue(self.scheduler.queue,
                                                       'ofpmanager_tableacquiredelay')):
            yield m

    def _reinitall(self):
        "Re-run _initialize_connection on every managed connection."
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))

    def _manage_existing(self):
        "Import connections already known by openflowserver, then announce 'synchronized'."
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Block until _manage_existing has finished."
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)

    def _manage_conns(self):
        "Main routine: keep managed_conns / endpoint_conns in sync with connection events."
        vb = self.vhostbind
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP,
                    _ismatch=lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN,
                    _ismatch=lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                e = self.apiroutine.event
                if self.apiroutine.matcher is conn_up:
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                                 add=[e.connection], remove=remove))
                else:
                    key = (e.createby.vhost, e.datapathid)
                    conns = self.managed_conns.get(key)
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            pass
                        else:
                            remove.append(e.connection)
                        if not conns:
                            del self.managed_conns[key]
                    # Also delete from endpoint_conns
                    ep = _get_endpoint(e.connection)
                    econns = self.endpoint_conns.get((e.createby.vhost, ep))
                    if econns is not None:
                        try:
                            econns.remove(e.connection)
                        except ValueError:
                            pass
                        if not econns:
                            del self.endpoint_conns[(e.createby.vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                                     add=[], remove=remove))
        finally:
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getconnections(self, datapathid, vhost=''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.get((vhost, datapathid), []))

    def getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid, vhost)

    def _getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        "Return the managed connection with the given auxiliary id, or None."
        conns = self.managed_conns.get((vhost, datapathid))
        if conns is None:
            return None
        for c in conns:
            if c.openflow_auxiliaryid == auxiliaryid:
                return c
        return None

    def waitconnection(self, datapathid, auxiliaryid=0, timeout=30, vhost=''):
        "Wait for a datapath connection"
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is None:
            for m in self.apiroutine.waitWithTimeout(
                    timeout,
                    OpenflowConnectionStateEvent.createMatcher(
                        datapathid, auxiliaryid,
                        OpenflowConnectionStateEvent.CONNECTION_SETUP,
                        _ismatch=lambda x: x.createby.vhost == vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getdatapathids(self, vhost=''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys() if k[0] == vhost]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return ``(vhost, datapathid)`` pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost=''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            conn_lists = self.managed_conns.values()
        else:
            conn_lists = (v for k, v in self.managed_conns.items() if k[0] == vhost)
        # Flatten the per-datapath lists into a single list of connections
        self.apiroutine.retvalue = list(itertools.chain.from_iterable(conn_lists))

    def getconnectionsbyendpoint(self, endpoint, vhost=''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost='', timeout=30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        if not name:
            # Nothing to resolve: look up the empty endpoint directly
            for m in self.getconnectionsbyendpoint('', vhost):
                yield m
            return
        request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP,
                   socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
        # Resolve is only allowed through the asynchronous resolver
        for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
            yield m
        for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
            yield m
        if self.apiroutine.timeout:
            raise IOError('Resolve hostname timeout: ' + name)
        if hasattr(self.apiroutine.event, 'error'):
            raise IOError('Cannot resolve hostname: ' + name)
        resp = self.apiroutine.event.response
        # Result stays None when no resolved address has connections (or resp is empty)
        self.apiroutine.retvalue = None
        for r in resp:
            raddr = r[4]
            if isinstance(raddr, tuple):
                # Ignore port
                endpoint = raddr[0]
            else:
                # Unix socket? This should not happen, but in case...
                endpoint = raddr
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
            if self.apiroutine.retvalue is not None:
                break

    def getendpoints(self, vhost=''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return ``(vhost, endpoint)`` pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def lastacquiredtables(self, vhost=""):
        "Get acquired table IDs"
        # No layout has been acquired yet
        if self._lastacquire is None:
            return None
        return self._lastacquire.get(vhost)

    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if modulename not in self.table_modules:
            self.table_modules.add(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this function a generator, as the API expects
        if False:
            yield

    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.remove(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this function a generator, as the API expects
        if False:
            yield
class OpenflowManager(Module):
    '''
    Manage Openflow Connections

    Tracks OpenFlow connections per (vhost, datapath ID) and per
    (vhost, endpoint), and computes the table ID layout requested by other
    modules through ``acquiretable`` / ``unacquiretable``.
    '''
    service = True
    _default_vhostbind = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> list of connections (one per auxiliary id)
        self.managed_conns = {}
        # (vhost, endpoint) -> list of connections coming from that endpoint
        self.endpoint_conns = {}
        # Names of modules that asked for tables through acquiretable()
        self.table_modules = set()
        # True while an _acquire_tables subroutine is running
        self._acquiring = False
        # Set when table_modules changed and the layout must be recomputed
        self._acquire_updated = False
        # Last successful layout: vhost -> (full_indices, tables); None until the first one
        self._lastacquire = None
        # Set once the connections already known by openflowserver are loaded
        self._synchronized = False
        self.createAPI(api(self.getconnections, self.apiroutine),
                       api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.acquiretable, self.apiroutine),
                       api(self.unacquiretable, self.apiroutine),
                       api(self.lastacquiredtables))

    def _add_connection(self, conn):
        """
        Register a newly set-up connection.

        An existing connection with the same vhost, datapath ID and auxiliary ID
        is replaced and unregistered from ``endpoint_conns``.

        :return: list of replaced connections (empty, or exactly one element)
        """
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault((vhost, conn.openflow_datapathid), [])
        remove = []
        for i, ci in enumerate(conns):
            if ci.openflow_auxiliaryid == conn.openflow_auxiliaryid:
                remove = [ci]
                ep = _get_endpoint(ci)
                econns = self.endpoint_conns.get((vhost, ep))
                if econns is not None:
                    try:
                        econns.remove(ci)
                    except ValueError:
                        pass
                    if not econns:
                        del self.endpoint_conns[(vhost, ep)]
                # Safe to mutate while enumerating: we stop right after
                del conns[i]
                break
        conns.append(conn)
        ep = _get_endpoint(conn)
        self.endpoint_conns.setdefault((vhost, ep), []).append(conn)
        # Only the main connection (auxiliary id 0) is initialized, and only
        # after a table layout has been acquired
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return remove

    def _initialize_connection(self, conn):
        """
        Reset a datapath and install the default inter-table flows.

        Deletes all flows (and all groups when this OpenFlow version has
        ``ofp_group_mod``). If a layout is acquired for the connection's vhost
        and the version supports goto-table, it adds a priority-0 flow in each
        table of a path that jumps to the next table of the same path. Finally
        it sends ``FlowInitialize``.
        """
        ofdef = conn.openflowdef
        flow_mod = ofdef.ofp_flow_mod(buffer_id=ofdef.OFP_NO_BUFFER,
                                      out_port=ofdef.OFPP_ANY,
                                      command=ofdef.OFPFC_DELETE)
        # These fields only exist in some OpenFlow versions
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            cmds.append(ofdef.ofp_group_mod(command=ofdef.OFPGC_DELETE,
                                            group_id=ofdef.OFPG_ALL))
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                # pathtable: pathname -> ((table_name, table_id), ...) ordered by table id
                cmds = [ofdef.ofp_flow_mod(table_id=t[i][1],
                                           command=ofdef.OFPFC_ADD,
                                           priority=0,
                                           buffer_id=ofdef.OFP_NO_BUFFER,
                                           out_port=ofdef.OFPP_ANY,
                                           out_group=ofdef.OFPG_ANY,
                                           match=ofdef.ofp_match_oxm(),
                                           instructions=[ofdef.ofp_instruction_goto_table(
                                               table_id=t[i + 1][1])])
                        for t in pathtable.values()
                        for i in range(0, len(t) - 1)]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        for m in self.apiroutine.waitForSend(
                FlowInitialize(conn, conn.openflow_datapathid, conn.protocol.vhost)):
            yield m

    def _compute_table_layout(self, requests, vh):
        """
        Compute the table ID layout for one vhost.

        :param requests: list of ``(table_list, vhosts)`` pairs, where ``vhosts``
                         is None (applies to every vhost) or a collection of
                         vhosts, and ``table_list`` holds
                         ``(name, (ancester, ...), pathname)`` entries
        :param vh: the vhost to compute the layout for
        :return: ``(full_indices, tables)``. ``full_indices`` is a list of
                 ``(name, table_id)`` in ID order. ``tables`` maps
                 pathname -> tuple of ``(name, table_id)``.
        :raises ValueError: on a table placed in two paths, malformed requests,
                            a dependency cycle, or more than 255 tables.
                            The error is logged before raising.
        """
        # name -> (set of not-yet-placed ancesters, set of descendants)
        graph = {}
        table_path = {}
        try:
            for table_list, req_vhosts in requests:
                if req_vhosts is None or vh in req_vhosts:
                    for name, ancesters, pathname in table_list:
                        if name in table_path:
                            if table_path[name] != pathname:
                                raise ValueError("table conflict detected: %r can not be in two path: %r and %r"
                                                 % (name, table_path[name], pathname))
                        else:
                            table_path[name] = pathname
                        if name not in graph:
                            graph[name] = (set(ancesters), set())
                        else:
                            graph[name][0].update(ancesters)
                        for anc in ancesters:
                            graph.setdefault(anc, (set(), set()))[1].add(name)
        except ValueError as exc:
            self._logger.error(str(exc))
            raise
        sequences = []

        def dfs_sort(current):
            # Place a table, then place each descendant once all its ancesters are placed
            sequences.append(current)
            for d in graph[current][1]:
                anc = graph[d][0]
                anc.remove(current)
                if not anc:
                    dfs_sort(d)

        # Tables without ancesters, ordered by (pathname, name)
        nopre_tables = sorted([k for k, v in graph.items() if not v[0]],
                              key=lambda x: (table_path.get(x, ''), x))
        for t in nopre_tables:
            dfs_sort(t)
        if len(sequences) < len(graph):
            # Tables never placed are part of (or depend on) a cycle
            rest_tables = set(graph.keys()).difference(sequences)
            self._logger.error("Circle detected in table acquiring, following tables are related: %r, vhost = %r",
                               sorted(rest_tables), vh)
            self._logger.error("Circle dependencies are: %s",
                               ", ".join(repr(tuple(graph[t][0])) + "=>" + t for t in rest_tables))
            raise ValueError("Circle detected in table acquiring, following tables are related: %r, vhost = %r"
                             % (sorted(rest_tables), vh))
        if len(sequences) > 255:
            self._logger.error("Table limit exceeded: %d tables (only 255 allowed), vhost = %r",
                               len(sequences), vh)
            raise ValueError("Table limit exceeded: %d tables (only 255 allowed), vhost = %r"
                             % (len(sequences), vh))
        full_indices = list(zip(sequences, itertools.count()))

        def path_of(entry):
            return table_path.get(entry[0], '')

        # sorted() is stable, so each path keeps its tables in ID order
        tables = dict((k, tuple(g))
                      for k, g in itertools.groupby(sorted(full_indices, key=path_of), path_of))
        return (full_indices, tables)

    def _acquire_tables(self):
        """
        Recompute the table layout until no more module changes are pending,
        then publish it with ``TableAcquireUpdate``.
        """
        try:
            while self._acquire_updated:
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(),)
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll(
                            (callAPI(self.apiroutine, module, 'gettablerequest', {})
                             for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests if vhs is not None for vh in vhs)
                    vhost_result = {}
                    for vh in vhosts:
                        try:
                            vhost_result[vh] = self._compute_table_layout(requests, vh)
                        except ValueError as exc:
                            # Already logged by _compute_table_layout
                            exception = exc
                            break
        finally:
            self._acquiring = False
        if exception is not None:
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(exception=exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                self._lastacquire = result
                self._reinitall()
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(result=result)):
                yield m

    def load(self, container):
        self.scheduler.queue.addSubQueue(1, TableAcquireDelayEvent.createMatcher(),
                                         'ofpmanager_tableacquiredelay')
        for m in container.waitForSend(TableAcquireUpdate(result=None)):
            yield m
        for m in Module.load(self, container):
            yield m

    def unload(self, container, force=False):
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(syscall_removequeue(self.scheduler.queue,
                                                       'ofpmanager_tableacquiredelay')):
            yield m

    def _reinitall(self):
        "Re-run _initialize_connection on every managed connection."
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))

    def _manage_existing(self):
        "Import connections already known by openflowserver, then announce 'synchronized'."
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Block until _manage_existing has finished."
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)

    def _manage_conns(self):
        "Main routine: keep managed_conns / endpoint_conns in sync with connection events."
        vb = self.vhostbind
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP,
                    _ismatch=lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN,
                    _ismatch=lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                e = self.apiroutine.event
                if self.apiroutine.matcher is conn_up:
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                                 add=[e.connection], remove=remove))
                else:
                    key = (e.createby.vhost, e.datapathid)
                    conns = self.managed_conns.get(key)
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            pass
                        else:
                            remove.append(e.connection)
                        if not conns:
                            del self.managed_conns[key]
                    # Also delete from endpoint_conns
                    ep = _get_endpoint(e.connection)
                    econns = self.endpoint_conns.get((e.createby.vhost, ep))
                    if econns is not None:
                        try:
                            econns.remove(e.connection)
                        except ValueError:
                            pass
                        if not econns:
                            del self.endpoint_conns[(e.createby.vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                                     add=[], remove=remove))
        finally:
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getconnections(self, datapathid, vhost=''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.get((vhost, datapathid), []))

    def getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid, vhost)

    def _getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        "Return the managed connection with the given auxiliary id, or None."
        conns = self.managed_conns.get((vhost, datapathid))
        if conns is None:
            return None
        for c in conns:
            if c.openflow_auxiliaryid == auxiliaryid:
                return c
        return None

    def waitconnection(self, datapathid, auxiliaryid=0, timeout=30, vhost=''):
        "Wait for a datapath connection"
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is None:
            for m in self.apiroutine.waitWithTimeout(
                    timeout,
                    OpenflowConnectionStateEvent.createMatcher(
                        datapathid, auxiliaryid,
                        OpenflowConnectionStateEvent.CONNECTION_SETUP,
                        _ismatch=lambda x: x.createby.vhost == vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getdatapathids(self, vhost=''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys() if k[0] == vhost]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost=''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            conn_lists = self.managed_conns.values()
        else:
            conn_lists = (v for k, v in self.managed_conns.items() if k[0] == vhost)
        # Flatten the per-datapath lists into a single list of connections
        self.apiroutine.retvalue = list(itertools.chain.from_iterable(conn_lists))

    def getconnectionsbyendpoint(self, endpoint, vhost=''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost='', timeout=30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        if not name:
            # Nothing to resolve: look up the empty endpoint directly
            for m in self.getconnectionsbyendpoint('', vhost):
                yield m
            return
        request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP,
                   socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
        # Resolve is only allowed through the asynchronous resolver
        for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
            yield m
        for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
            yield m
        if self.apiroutine.timeout:
            raise IOError('Resolve hostname timeout: ' + name)
        if hasattr(self.apiroutine.event, 'error'):
            raise IOError('Cannot resolve hostname: ' + name)
        resp = self.apiroutine.event.response
        # Result stays None when no resolved address has connections (or resp is empty)
        self.apiroutine.retvalue = None
        for r in resp:
            raddr = r[4]
            if isinstance(raddr, tuple):
                # Ignore port
                endpoint = raddr[0]
            else:
                # Unix socket? This should not happen, but in case...
                endpoint = raddr
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
            if self.apiroutine.retvalue is not None:
                break

    def getendpoints(self, vhost=''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def lastacquiredtables(self, vhost=""):
        "Get acquired table IDs"
        # No layout has been acquired yet
        if self._lastacquire is None:
            return None
        return self._lastacquire.get(vhost)

    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if modulename not in self.table_modules:
            self.table_modules.add(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this function a generator, as the API expects
        if False:
            yield

    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.remove(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this function a generator, as the API expects
        if False:
            yield
class ObjectDB(Module):
    """
    Abstract transaction layer for KVDB.

    Keeps a local cache of managed (watched) objects and serves ``get``,
    ``getonce``, ``watch``, ``unwatch`` and ``walk`` requests from a single
    update routine (``_update``). Transactional updates go through the
    ``kvstorage`` module's ``updateallwithtime`` API (see ``transact``).
    """
    service = True
    # Priority for object update event
    _default_objectupdatepriority = 450
    # Enable debugging mode for updater: all updaters will be called for an extra time
    # to make sure it does not crash with multiple calls
    _default_debuggingupdater = False

    def __init__(self, server):
        Module.__init__(self, server)
        # key -> managed (cached) object
        self._managed_objs = {}
        # key -> set of request ids which keep this key managed
        self._watches = {}
        # keys currently registered in the update notifier
        self._watchedkeys = set()
        # pending request tuples (keys, rid, type, ...) consumed by _update
        self._requests = []
        self._transactno = 0
        # True after a failed kvstorage read: results are served from the cache
        self._stale = False
        # keys that must be re-read in the next update loop
        self._updatekeys = set()
        # key -> newest (createtime, updateversion) known from notifications
        self._update_version = {}
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._update
        self.routines.append(self.apiroutine)
        self.createAPI(api(self.mget, self.apiroutine),
                       api(self.get, self.apiroutine),
                       api(self.mgetonce, self.apiroutine),
                       api(self.getonce, self.apiroutine),
                       api(self.mwatch, self.apiroutine),
                       api(self.watch, self.apiroutine),
                       api(self.munwatch, self.apiroutine),
                       api(self.unwatch, self.apiroutine),
                       api(self.unwatchall, self.apiroutine),
                       api(self.transact, self.apiroutine),
                       api(self.watchlist),
                       api(self.walk, self.apiroutine)
                       )

    def load(self, container):
        """
        Register the ``dataobjectupdate`` sub-queue, create the update
        notifier through the ``updatenotifier`` module, then load the module.
        """
        self.scheduler.queue.addSubQueue(
                self.objectupdatepriority, dataobj.DataObjectUpdateEvent.createMatcher(), 'dataobjectupdate')
        for m in callAPI(container, 'updatenotifier', 'createnotifier'):
            yield m
        self._notifier = container.retvalue
        for m in Module.load(self, container):
            yield m
        self.routines.append(self._notifier)

    def unload(self, container, force=False):
        "Remove the ``dataobjectupdate`` sub-queue and unload the module."
        for m in container.syscall(syscall_removequeue(self.scheduler.queue, 'dataobjectupdate')):
            yield m
        for m in Module.unload(self, container, force=force):
            yield m

    def _update(self):
        """
        Main routine: waits for update notifications or new requests, then
        retrieves values from kvstorage, updates the managed objects, and
        replies to the pending requests.
        """
        timestamp = '%012x' % (int(time() * 1000),) + '-'
        notification_matcher = self._notifier.notification_matcher(False)

        def copywithkey(obj, key):
            # Deep copy used for "getonce" results, so callers cannot modify managed objects
            newobj = deepcopy(obj)
            if hasattr(newobj, 'setkey'):
                newobj.setkey(key)
            return newobj

        def getversion(obj):
            # (createtime, updateversion); a missing object is (0, -1)
            if obj is None:
                return (0, -1)
            else:
                return (getattr(obj, 'kvdb_createtime', 0), getattr(obj, 'kvdb_updateversion', 0))

        def isnewer(obj, version):
            if obj is None:
                return version[1] != -1
            else:
                return getversion(obj) > version

        request_matcher = RetrieveRequestSend.createMatcher()

        def onupdate(event, matcher):
            # Record watched keys from a notification and remember the newest version seen
            update_keys = self._watchedkeys.intersection([_str(k) for k in event.keys])
            self._updatekeys.update(update_keys)
            if event.extrainfo:
                for k, v in zip(event.keys, event.extrainfo):
                    k = _str(k)
                    if k in update_keys:
                        v = tuple(v)
                        oldv = self._update_version.get(k, (0, -1))
                        if oldv < v:
                            self._update_version[k] = v
            else:
                # No version information: forget the known versions
                for k in event.keys:
                    try:
                        del self._update_version[_str(k)]
                    except KeyError:
                        pass

        def updateinner():
            processing_requests = []
            # New managed keys
            retrieve_list = set()
            orig_retrieve_list = set()
            retrieveonce_list = set()
            orig_retrieveonce_list = set()
            # Retrieved values are stored in update_result before merging into current storage
            update_result = {}
            # key => [(walker_func, (original_keys, rid)), ...]
            walkers = {}
            self._loopCount = 0
            # A request-id -> retrieve set dictionary to store the saved keys
            savelist = {}

            def updateloop():
                while (retrieve_list or self._updatekeys or self._requests):
                    watch_keys = set()
                    # Updated keys
                    update_list = set()
                    if self._loopCount >= 10 and not retrieve_list:
                        if not self._updatekeys:
                            break
                        elif self._loopCount >= 100:
                            # Too many updates, we must stop to respond
                            self._logger.warning("There are still database updates after 100 loops of mget, respond with potential inconsistent values")
                            break
                    if self._updatekeys:
                        update_list.update(self._updatekeys)
                        self._updatekeys.clear()
                    if self._requests:
                        # Processing requests
                        for r in self._requests:
                            if r[2] == 'unwatch':
                                try:
                                    for k in r[0]:
                                        s = self._watches.get(k)
                                        if s:
                                            s.discard(r[3])
                                            if not s:
                                                del self._watches[k]
                                    # Do not need to wait
                                except Exception as exc:
                                    for m in self.apiroutine.waitForSend(RetrieveReply(r[1], exception = exc)):
                                        yield m
                                else:
                                    for m in self.apiroutine.waitForSend(RetrieveReply(r[1], result = None)):
                                        yield m
                            elif r[2] == 'watch':
                                retrieve_list.update(r[0])
                                orig_retrieve_list.update(r[0])
                                for k in r[0]:
                                    self._watches.setdefault(k, set()).add(r[3])
                                processing_requests.append(r)
                            elif r[2] == 'get':
                                retrieve_list.update(r[0])
                                orig_retrieve_list.update(r[0])
                                processing_requests.append(r)
                            elif r[2] == 'walk':
                                retrieve_list.update(r[0])
                                processing_requests.append(r)
                                for k, v in r[3].items():
                                    walkers.setdefault(k, []).append((v, (r[0], r[1])))
                            else:
                                # 'getonce'
                                retrieveonce_list.update(r[0])
                                orig_retrieveonce_list.update(r[0])
                                processing_requests.append(r)
                        del self._requests[:]
                    if retrieve_list:
                        watch_keys.update(retrieve_list)
                    # Add watch_keys to notification
                    watch_keys.difference_update(self._watchedkeys)
                    if watch_keys:
                        for k in watch_keys:
                            if k in update_result:
                                self._update_version[k] = getversion(update_result[k])
                        for m in self._notifier.add_listen(*tuple(watch_keys.difference(self._watchedkeys))):
                            yield m
                        self._watchedkeys.update(watch_keys)
                    get_list_set = update_list.union(retrieve_list.union(retrieveonce_list).difference(self._managed_objs.keys()).difference(update_result.keys()))
                    get_list = list(get_list_set)
                    if get_list:
                        try:
                            for m in callAPI(self.apiroutine, 'kvstorage', 'mget', {'keys': get_list}):
                                yield m
                        except QuitException:
                            raise
                        except Exception:
                            # Serve with cache
                            if not self._stale:
                                self._logger.warning('KVStorage retrieve failed, serve with cache', exc_info = True)
                            self._stale = True
                            # Discard all retrieved results
                            update_result.clear()
                            # Retry update later
                            self._updatekeys.update(update_list)
                            changed_set = set()
                        else:
                            result = self.apiroutine.retvalue
                            self._stale = False
                            for k, v in zip(get_list, result):
                                if v is not None and hasattr(v, 'setkey'):
                                    v.setkey(k)
                                if k in self._watchedkeys and k not in self._update_version:
                                    self._update_version[k] = getversion(v)
                            changed_set = set(k for k, v in zip(get_list, result) if k not in update_result or getversion(v) != getversion(update_result[k]))
                            update_result.update(zip(get_list, result))
                    else:
                        changed_set = set()
                    # All keys which should be retrieved in next loop
                    new_retrieve_list = set()
                    # Keys which should be retrieved in next loop for a single walk
                    new_retrieve_keys = set()
                    # Keys that are used in current walk will be retrieved again in next loop
                    used_keys = set()

                    # We separate the original data and new retrieved data space, and do not allow
                    # cross usage, to prevent discontinue results
                    def walk_original(key):
                        if hasattr(key, 'getkey'):
                            key = key.getkey()
                        key = _str(key)
                        if key not in self._watchedkeys:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        elif self._stale:
                            if key not in self._managed_objs:
                                new_retrieve_keys.add(key)
                            else:
                                used_keys.add(key)
                            return self._managed_objs.get(key)
                        elif key in changed_set:
                            # We are retrieving from the old result, do not allow to use new data
                            used_keys.add(key)
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        elif key in update_result:
                            used_keys.add(key)
                            return update_result[key]
                        elif key in self._managed_objs:
                            used_keys.add(key)
                            return self._managed_objs[key]
                        else:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')

                    def walk_new(key):
                        if hasattr(key, 'getkey'):
                            key = key.getkey()
                        key = _str(key)
                        if key not in self._watchedkeys:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        elif key in get_list_set:
                            # We are retrieving from the new data
                            used_keys.add(key)
                            return update_result[key]
                        elif key in self._managed_objs or key in update_result:
                            # Do not allow the old data
                            used_keys.add(key)
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        else:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')

                    def create_walker(orig_key):
                        # Choose which data space a walk starting at orig_key may read
                        if self._stale:
                            return walk_original
                        elif orig_key in changed_set:
                            return walk_new
                        else:
                            return walk_original

                    walker_set = set()

                    def default_walker(key, obj, walk):
                        # Follow kvdb_retrievelist() references once per key in this loop
                        if key in walker_set:
                            return
                        else:
                            walker_set.add(key)
                        if hasattr(obj, 'kvdb_retrievelist'):
                            rl = obj.kvdb_retrievelist()
                            for k in rl:
                                try:
                                    newobj = walk(k)
                                except KeyError:
                                    pass
                                else:
                                    if newobj is not None:
                                        default_walker(k, newobj, walk)

                    for k in orig_retrieve_list:
                        v = update_result.get(k)
                        if v is not None:
                            new_retrieve_keys.clear()
                            used_keys.clear()
                            default_walker(k, v, create_walker(k))
                            if new_retrieve_keys:
                                new_retrieve_list.update(new_retrieve_keys)
                                self._updatekeys.update(used_keys)
                                self._updatekeys.add(k)
                    savelist.clear()
                    for k, ws in walkers.items():
                        # k: the walker key
                        # ws: list of [walker_func, (request_original_keys, rid)]
                        # Retry every walker, starts with k, with the value of v
                        if k in update_result:
                            # The value is newly retrieved
                            v = update_result.get(k)
                        else:
                            # Use the stored value
                            v = self._managed_objs.get(k)
                        if ws:
                            for w, r in list(ws):
                                # w: walker_func
                                # r: (request_original_keys, rid)
                                # Custom walker
                                def save(key):
                                    if hasattr(key, 'getkey'):
                                        key = key.getkey()
                                    key = _str(key)
                                    if key != k and key not in used_keys:
                                        raise ValueError('Cannot save a key without walk')
                                    savelist.setdefault(r[1], set()).add(key)
                                try:
                                    new_retrieve_keys.clear()
                                    used_keys.clear()
                                    w(k, v, create_walker(k), save)
                                except Exception as exc:
                                    # If one walker fails, the whole request fails: drop every
                                    # walker registered by this request and reply with the exception
                                    failed_rid = r[1]
                                    self._logger.warning("A walker raises an exception which rolls back the whole walk process %r. "
                                                         "walker = %r, start key = %r, new_retrieve_keys = %r, used_keys = %r",
                                                         failed_rid, w, k, new_retrieve_keys, used_keys, exc_info=True)
                                    # Walkers are registered under the walkerdict keys, which may
                                    # differ from the request keys, so purge every walker list.
                                    # In-place slice assignment does not mutate the dict itself.
                                    for walker_list in walkers.values():
                                        walker_list[:] = [(w0, r0) for w0, r0 in walker_list if r0[1] != failed_rid]
                                    processing_requests[:] = [r0 for r0 in processing_requests if r0[1] != failed_rid]
                                    # The walker may have failed before saving any key
                                    savelist.pop(failed_rid, None)
                                    for m in self.apiroutine.waitForSend(RetrieveReply(failed_rid, exception = exc)):
                                        yield m
                                else:
                                    if new_retrieve_keys:
                                        new_retrieve_list.update(new_retrieve_keys)
                                        self._updatekeys.update(used_keys)
                                        self._updatekeys.add(k)
                    for save in savelist.values():
                        for k in save:
                            v = update_result.get(k)
                            if v is not None:
                                # If we retrieved a new value, we should also retrieved the references
                                # from this value
                                new_retrieve_keys.clear()
                                used_keys.clear()
                                default_walker(k, v, create_walker(k))
                                if new_retrieve_keys:
                                    new_retrieve_list.update(new_retrieve_keys)
                                    self._updatekeys.update(used_keys)
                                    self._updatekeys.add(k)
                    retrieve_list.clear()
                    retrieveonce_list.clear()
                    retrieve_list.update(new_retrieve_list)
                    self._loopCount += 1
                    if self._stale:
                        # Storage is unavailable: just register the notifications and stop
                        watch_keys = set(retrieve_list)
                        watch_keys.difference_update(self._watchedkeys)
                        if watch_keys:
                            for m in self._notifier.add_listen(*tuple(watch_keys)):
                                yield m
                            self._watchedkeys.update(watch_keys)
                        break

            while True:
                for m in self.apiroutine.withCallback(updateloop(), onupdate, notification_matcher):
                    yield m
                if self._loopCount >= 100 or self._stale:
                    break
                # If some updated result is newer than the notification version, we should wait for the notification
                should_wait = False
                for k, v in update_result.items():
                    if k in self._watchedkeys:
                        oldv = self._update_version.get(k)
                        if oldv is not None and isnewer(v, oldv):
                            should_wait = True
                            break
                if should_wait:
                    for m in self.apiroutine.waitWithTimeout(0.2, notification_matcher):
                        yield m
                    if self.apiroutine.timeout:
                        break
                    else:
                        onupdate(self.apiroutine.event, self.apiroutine.matcher)
                else:
                    break
            # Update result
            send_events = []
            self._transactno += 1
            transactid = '%s%016x' % (timestamp, self._transactno)
            update_objs = []
            for k, v in update_result.items():
                if k in self._watchedkeys:
                    if v is None:
                        oldv = self._managed_objs.get(k)
                        if oldv is not None:
                            if hasattr(oldv, 'kvdb_detach'):
                                oldv.kvdb_detach()
                                update_objs.append((k, oldv, dataobj.DataObjectUpdateEvent.DELETED))
                            else:
                                update_objs.append((k, None, dataobj.DataObjectUpdateEvent.DELETED))
                            del self._managed_objs[k]
                    else:
                        oldv = self._managed_objs.get(k)
                        if oldv is not None:
                            if oldv != v:
                                if oldv and hasattr(oldv, 'kvdb_update'):
                                    # Update in place so existing references stay valid
                                    oldv.kvdb_update(v)
                                    update_objs.append((k, oldv, dataobj.DataObjectUpdateEvent.UPDATED))
                                else:
                                    if hasattr(oldv, 'kvdb_detach'):
                                        oldv.kvdb_detach()
                                    self._managed_objs[k] = v
                                    update_objs.append((k, v, dataobj.DataObjectUpdateEvent.UPDATED))
                        else:
                            self._managed_objs[k] = v
                            update_objs.append((k, v, dataobj.DataObjectUpdateEvent.UPDATED))
            for k in update_result.keys():
                v = self._managed_objs.get(k)
                if v is not None and hasattr(v, 'kvdb_retrievefinished'):
                    v.kvdb_retrievefinished(self._managed_objs)
            allkeys = tuple(k for k, _, _ in update_objs)
            send_events.extend((dataobj.DataObjectUpdateEvent(k, transactid, t, object = v, allkeys = allkeys) for k, v, t in update_objs))
            # Process requests
            for r in processing_requests:
                if r[2] == 'get':
                    objs = [self._managed_objs.get(k) for k in r[0]]
                    for k, v in zip(r[0], objs):
                        if v is not None:
                            self._watches.setdefault(k, set()).add(r[3])
                    result = [o.create_reference() if o is not None and hasattr(o, 'create_reference') else o for o in objs]
                elif r[2] == 'watch':
                    result = [(v.create_reference() if hasattr(v, 'create_reference') else v) if v is not None else dataobj.ReferenceObject(k)
                              for k, v in ((k, self._managed_objs.get(k)) for k in r[0])]
                elif r[2] == 'walk':
                    saved_keys = list(savelist.get(r[1], []))
                    for k in saved_keys:
                        self._watches.setdefault(k, set()).add(r[4])
                    objs = [self._managed_objs.get(k) for k in saved_keys]
                    result = (saved_keys, [o.create_reference() if hasattr(o, 'create_reference') else o if o is not None else dataobj.ReferenceObject(k)
                                           for k, o in zip(saved_keys, objs)])
                else:
                    result = [copywithkey(update_result.get(k, self._managed_objs.get(k)), k) for k in r[0]]
                send_events.append(RetrieveReply(r[1], result = result, stale = self._stale))
            # Use DFS to remove unwatched objects
            mark_set = set()

            def dfs(k):
                if k in mark_set:
                    return
                mark_set.add(k)
                v = self._managed_objs.get(k)
                if v is not None and hasattr(v, 'kvdb_internalref'):
                    for k2 in v.kvdb_internalref():
                        dfs(k2)

            for k in self._watches.keys():
                dfs(k)

            def output_result():
                # Stop listening to keys that are no longer reachable from any watch
                remove_keys = self._watchedkeys.difference(mark_set)
                if remove_keys:
                    self._watchedkeys.difference_update(remove_keys)
                    for m in self._notifier.remove_listen(*tuple(remove_keys)):
                        yield m
                    for k in remove_keys:
                        if k in self._managed_objs:
                            del self._managed_objs[k]
                        if k in self._update_version:
                            del self._update_version[k]
                for e in send_events:
                    for m in self.apiroutine.waitForSend(e):
                        yield m

            for m in self.apiroutine.withCallback(output_result(), onupdate):
                yield m

        while True:
            if not self._updatekeys and not self._requests:
                yield (notification_matcher, request_matcher)
                if self.apiroutine.matcher is notification_matcher:
                    onupdate(self.apiroutine.event, self.apiroutine.matcher)
            for m in updateinner():
                yield m

    def mget(self, keys, requestid, nostale = False):
        """
        Get multiple objects and manage them. Return references to the objects.

        :param keys: keys to retrieve
        :param requestid: owner of the watches created for the existing keys
        :param nostale: raise StaleResultException if the result is served from the cache
        """
        keys = tuple(_str2(k) for k in keys)
        notify = not self._requests
        rid = object()
        self._requests.append((keys, rid, 'get', requestid))
        if notify:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        if hasattr(self.apiroutine.event, 'exception'):
            raise self.apiroutine.event.exception
        if nostale and self.apiroutine.event.stale:
            raise StaleResultException(self.apiroutine.event.result)
        self.apiroutine.retvalue = self.apiroutine.event.result

    def get(self, key, requestid, nostale = False):
        """
        Get an object from specified key, and manage the object.
        Return a reference to the object or None if not exists.
        """
        for m in self.mget([key], requestid, nostale):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]

    def mgetonce(self, keys, nostale = False):
        "Get multiple objects, return copies of them. Referenced objects are not retrieved."
        keys = tuple(_str2(k) for k in keys)
        notify = not self._requests
        rid = object()
        self._requests.append((keys, rid, 'getonce'))
        if notify:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        if hasattr(self.apiroutine.event, 'exception'):
            raise self.apiroutine.event.exception
        if nostale and self.apiroutine.event.stale:
            raise StaleResultException(self.apiroutine.event.result)
        self.apiroutine.retvalue = self.apiroutine.event.result

    def getonce(self, key, nostale = False):
        "Get a object without manage it. Return a copy of the object, or None if not exists. Referenced objects are not retrieved."
        for m in self.mgetonce([key], nostale):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]

    def watch(self, key, requestid, nostale = False):
        """
        Try to find an object and return a reference. Use ``reference.isdeleted()`` to test
        whether the object exists.
        Use ``reference.wait(container)`` to wait for the object to be existed.
        """
        for m in self.mwatch([key], requestid, nostale):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]

    def mwatch(self, keys, requestid, nostale = False):
        "Try to return all the references, see ``watch()``"
        keys = tuple(_str2(k) for k in keys)
        notify = not self._requests
        rid = object()
        # The request must be appended as ONE tuple, like the other request kinds
        self._requests.append((keys, rid, 'watch', requestid))
        if notify:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        if hasattr(self.apiroutine.event, 'exception'):
            raise self.apiroutine.event.exception
        if nostale and self.apiroutine.event.stale:
            raise StaleResultException(self.apiroutine.event.result)
        self.apiroutine.retvalue = self.apiroutine.event.result

    def unwatch(self, key, requestid):
        "Cancel management of a key"
        for m in self.munwatch([key], requestid):
            yield m
        self.apiroutine.retvalue = None

    def unwatchall(self, requestid):
        "Cancel management for all keys that are managed by requestid"
        keys = [k for k, v in self._watches.items() if requestid in v]
        for m in self.munwatch(keys, requestid):
            yield m

    def munwatch(self, keys, requestid):
        "Cancel management of keys"
        keys = tuple(_str2(k) for k in keys)
        notify = not self._requests
        rid = object()
        self._requests.append((keys, rid, 'unwatch', requestid))
        if notify:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        if hasattr(self.apiroutine.event, 'exception'):
            raise self.apiroutine.event.exception
        self.apiroutine.retvalue = None

    def transact(self, keys, updater, withtime = False):
        """
        Try to update keys in a transact, with an ``updater(keys, values)``,
        which returns ``(updated_keys, updated_values)``.

        The updater may be called more than once. If ``withtime = True``,
        the updater should take three parameters:
        ``(keys, values, timestamp)`` with timestamp as the server time
        """
        keys = tuple(_str2(k) for k in keys)
        updated_ref = [None, None]
        # Index keys (unique/multi key references) needed by the update
        extra_keys = []
        # Key sets which contain the index keys
        extra_key_set = []
        # Keys removed automatically through kvdb_autoremove()
        auto_remove_keys = set()
        orig_len = len(keys)

        def updater_with_key(keys, values, timestamp):
            # Wraps the user updater and automatically manages index keys.
            # Key layout: keys[:orig_len] | auto_remove_keys | extra_keys | extra_key_set
            remove_uniquekeys = []
            remove_multikeys = []
            update_uniquekeys = []
            update_multikeys = []
            keystart = orig_len + len(auto_remove_keys)
            for v in values[:keystart]:
                if v is not None:
                    if hasattr(v, 'kvdb_uniquekeys'):
                        remove_uniquekeys.extend((k, v.create_weakreference()) for k in v.kvdb_uniquekeys())
                    if hasattr(v, 'kvdb_multikeys'):
                        remove_multikeys.extend((k, v.create_weakreference()) for k in v.kvdb_multikeys())
            if self.debuggingupdater:
                # Updater may be called more than once, ensure that this updater does not crash
                # on multiple calls
                kc = keys[:orig_len]
                vc = [v.clone_instance() if v is not None and hasattr(v, 'clone_instance') else deepcopy(v) for v in values[:orig_len]]
                if withtime:
                    updated_keys, updated_values = updater(kc, vc, timestamp)
                else:
                    updated_keys, updated_values = updater(kc, vc)
            if withtime:
                updated_keys, updated_values = updater(keys[:orig_len], values[:orig_len], timestamp)
            else:
                updated_keys, updated_values = updater(keys[:orig_len], values[:orig_len])
            for v in updated_values:
                if v is not None:
                    if hasattr(v, 'kvdb_uniquekeys'):
                        update_uniquekeys.extend((k, v.create_weakreference()) for k in v.kvdb_uniquekeys())
                    if hasattr(v, 'kvdb_multikeys'):
                        update_multikeys.extend((k, v.create_weakreference()) for k in v.kvdb_multikeys())
            extrakeysdict = dict(zip(keys[keystart:keystart + len(extra_keys)],
                                     values[keystart:keystart + len(extra_keys)]))
            extrakeysetdict = dict(zip(keys[keystart + len(extra_keys):keystart + len(extra_keys) + len(extra_key_set)],
                                       values[keystart + len(extra_keys):keystart + len(extra_keys) + len(extra_key_set)]))
            tempdict = {}
            old_values = dict(zip(keys, values))
            updated_keyset = set(updated_keys)
            try:
                append_remove = set()
                autoremove_keys = set()

                # Use DFS to find auto remove keys
                def dfs(k):
                    if k in autoremove_keys:
                        return
                    autoremove_keys.add(k)
                    if k not in old_values:
                        append_remove.add(k)
                    else:
                        oldv = old_values[k]
                        if oldv is not None and hasattr(oldv, 'kvdb_autoremove'):
                            for k2 in oldv.kvdb_autoremove():
                                dfs(k2)

                for k, v in zip(updated_keys, updated_values):
                    if v is None:
                        dfs(k)
                if append_remove:
                    raise _NeedMoreKeysException()
                for k, v in remove_uniquekeys:
                    # NOTE(review): this test uses auto_remove_keys while the retry branch
                    # below uses autoremove_keys; an original key removed only through
                    # kvdb_autoremove may keep its index entries - confirm intent
                    if v.getkey() not in updated_keyset and v.getkey() not in auto_remove_keys:
                        # This key is not updated, keep the indices untouched
                        continue
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    elif extrakeysdict[k] is not None and extrakeysdict[k].ref.getkey() == v.getkey():
                        # If the unique key does not reference to the correct object
                        # there may be an error, but we ignore this.
                        # Save in a temporary dictionary. We may restore it later.
                        tempdict[k] = extrakeysdict[k]
                        extrakeysdict[k] = None
                        setkey = UniqueKeyReference.get_keyset_from_key(k)
                        if setkey not in extrakeysetdict:
                            raise _NeedMoreKeysException()
                        else:
                            ks = extrakeysetdict[setkey]
                            if ks is None:
                                ks = UniqueKeySet.create_from_key(setkey)
                                extrakeysetdict[setkey] = ks
                            ks.set.dataset().discard(WeakReferenceObject(k))
                for k, v in remove_multikeys:
                    if v.getkey() not in updated_keyset and v.getkey() not in auto_remove_keys:
                        # This key is not updated, keep the indices untouched
                        continue
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    else:
                        mk = extrakeysdict[k]
                        if mk is not None:
                            mk.set.dataset().discard(v)
                            if not mk.set.dataset():
                                tempdict[k] = extrakeysdict[k]
                                extrakeysdict[k] = None
                                setkey = MultiKeyReference.get_keyset_from_key(k)
                                if setkey not in extrakeysetdict:
                                    raise _NeedMoreKeysException()
                                else:
                                    ks = extrakeysetdict[setkey]
                                    if ks is None:
                                        ks = MultiKeySet.create_from_key(setkey)
                                        extrakeysetdict[setkey] = ks
                                    ks.set.dataset().discard(WeakReferenceObject(k))
                for k, v in update_uniquekeys:
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    elif extrakeysdict[k] is not None and extrakeysdict[k].ref.getkey() != v.getkey():
                        raise AlreadyExistsException('Unique key conflict for %r and %r, with key %r' %
                                                     (extrakeysdict[k].ref.getkey(), v.getkey(), k))
                    elif extrakeysdict[k] is None:
                        lv = tempdict.get(k, None)
                        if lv is not None and lv.ref.getkey() == v.getkey():
                            # Restore this value
                            nv = lv
                        else:
                            nv = UniqueKeyReference.create_from_key(k)
                            nv.ref = ReferenceObject(v.getkey())
                        extrakeysdict[k] = nv
                        setkey = UniqueKeyReference.get_keyset_from_key(k)
                        if setkey not in extrakeysetdict:
                            raise _NeedMoreKeysException()
                        else:
                            ks = extrakeysetdict[setkey]
                            if ks is None:
                                ks = UniqueKeySet.create_from_key(setkey)
                                extrakeysetdict[setkey] = ks
                            ks.set.dataset().add(nv.create_weakreference())
                for k, v in update_multikeys:
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    else:
                        mk = extrakeysdict[k]
                        if mk is None:
                            mk = tempdict.get(k, None)
                            if mk is None:
                                mk = MultiKeyReference.create_from_key(k)
                                mk.set = DataObjectSet()
                            setkey = MultiKeyReference.get_keyset_from_key(k)
                            if setkey not in extrakeysetdict:
                                raise _NeedMoreKeysException()
                            else:
                                ks = extrakeysetdict[setkey]
                                if ks is None:
                                    ks = MultiKeySet.create_from_key(setkey)
                                    extrakeysetdict[setkey] = ks
                                ks.set.dataset().add(mk.create_weakreference())
                        mk.set.dataset().add(v)
                        extrakeysdict[k] = mk
            except _NeedMoreKeysException:
                # Prepare the keys for the next try; the exception makes transact retry
                extra_keys[:] = list(set(itertools.chain((k for k, v in remove_uniquekeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                         (k for k, v in remove_multikeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                         (k for k, _ in update_uniquekeys),
                                                         (k for k, _ in update_multikeys))))
                extra_key_set[:] = list(set(itertools.chain((UniqueKeyReference.get_keyset_from_key(k) for k, v in remove_uniquekeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                            (MultiKeyReference.get_keyset_from_key(k) for k, v in remove_multikeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                            (UniqueKeyReference.get_keyset_from_key(k) for k, _ in update_uniquekeys),
                                                            (MultiKeyReference.get_keyset_from_key(k) for k, _ in update_multikeys))))
                auto_remove_keys.clear()
                auto_remove_keys.update(autoremove_keys.difference(keys[:orig_len])
                                                       .difference(extra_keys)
                                                       .difference(extra_key_set))
                raise
            else:
                extrakeys_list = list(extrakeysdict.items())
                extrakeyset_list = list(extrakeysetdict.items())
                autoremove_list = list(autoremove_keys.difference(updated_keys)
                                                      .difference(extrakeysdict.keys())
                                                      .difference(extrakeysetdict.keys()))
                return (tuple(itertools.chain(updated_keys,
                                              (k for k, _ in extrakeys_list),
                                              (k for k, _ in extrakeyset_list),
                                              autoremove_list)),
                        tuple(itertools.chain(updated_values,
                                              (v for _, v in extrakeys_list),
                                              (v for _, v in extrakeyset_list),
                                              [None] * len(autoremove_list))))

        def object_updater(keys, values, timestamp):
            # Maintains kvdb_createtime / kvdb_updateversion and records the new versions
            old_version = {}
            for k, v in zip(keys, values):
                if v is not None and hasattr(v, 'setkey'):
                    v.setkey(k)
                if v is not None and hasattr(v, 'kvdb_createtime'):
                    old_version[k] = (getattr(v, 'kvdb_createtime'), getattr(v, 'kvdb_updateversion', 1))
            updated_keys, updated_values = updater_with_key(keys, values, timestamp)
            updated_ref[0] = tuple(updated_keys)
            new_version = []
            for k, v in zip(updated_keys, updated_values):
                if v is None:
                    new_version.append((timestamp, -1))
                elif k in old_version:
                    ov = old_version[k]
                    setattr(v, 'kvdb_createtime', ov[0])
                    setattr(v, 'kvdb_updateversion', ov[1] + 1)
                    new_version.append((ov[0], ov[1] + 1))
                else:
                    setattr(v, 'kvdb_createtime', timestamp)
                    setattr(v, 'kvdb_updateversion', 1)
                    new_version.append((timestamp, 1))
            updated_ref[1] = new_version
            return (updated_keys, updated_values)

        while True:
            try:
                for m in callAPI(self.apiroutine, 'kvstorage', 'updateallwithtime',
                                 {'keys': keys + tuple(auto_remove_keys) +
                                          tuple(extra_keys) + tuple(extra_key_set),
                                  'updater': object_updater}):
                    yield m
            except _NeedMoreKeysException:
                # More index keys are required, retry with the extended key list
                pass
            else:
                break
        # Short cut update notification
        update_keys = self._watchedkeys.intersection(updated_ref[0])
        self._updatekeys.update(update_keys)
        for k, v in zip(updated_ref[0], updated_ref[1]):
            k = _str(k)
            if k in update_keys:
                v = tuple(v)
                oldv = self._update_version.get(k, (0, -1))
                if oldv < v:
                    self._update_version[k] = v
        for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
            yield m
        for m in self._notifier.publish(updated_ref[0], updated_ref[1]):
            yield m

    def watchlist(self, requestid = None):
        """
        Return a dictionary whose keys are database keys, and values are lists of request ids.
        Optionally filtered by request id
        """
        return dict((k, list(v)) for k, v in self._watches.items() if requestid is None or requestid in v)

    def walk(self, keys, walkerdict, requestid, nostale = False):
        """
        Recursively retrieve keys with customized functions.
        walkerdict is a dictionary ``key->walker(key, obj, walk, save)``.
        """
        keys = tuple(_str2(k) for k in keys)
        notify = not self._requests
        rid = object()
        self._requests.append((keys, rid, 'walk', dict(walkerdict), requestid))
        if notify:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        if hasattr(self.apiroutine.event, 'exception'):
            raise self.apiroutine.event.exception
        if nostale and self.apiroutine.event.stale:
            raise StaleResultException(self.apiroutine.event.result)
        self.apiroutine.retvalue = self.apiroutine.event.result
class OVSDBManager(Module):
    '''
    Manage OVSDB connections: track the system-id, bridges (datapath IDs) and
    endpoint address of every managed OVSDB JSON-RPC connection, and expose
    them through the module API.
    '''
    service = True
    # Read as self.vhostbind: only connections whose protocol vhost is in this
    # collection are managed; None manages connections from every vhost
    _default_vhostbind = None
    # Read as self.bridgenames: only bridges with these names are managed;
    # None manages every bridge
    _default_bridgenames = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapath_id) -> connection
        self.managed_conns = {}
        # (vhost, system_id) -> connection
        self.managed_systemids = {}
        # connection -> list of (vhost, datapath_id, bridge_name, bridge_uuid)
        self.managed_bridges = {}
        self.managed_routines = []
        # (vhost, endpoint) -> list of connections
        self.endpoint_conns = {}
        self.createAPI(api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getbridges, self.apiroutine),
                       api(self.getbridge, self.apiroutine),
                       api(self.getbridgebyuuid, self.apiroutine),
                       api(self.waitbridge, self.apiroutine),
                       api(self.waitbridgebyuuid, self.apiroutine),
                       api(self.getsystemids, self.apiroutine),
                       api(self.getallsystemids, self.apiroutine),
                       api(self.getconnectionbysystemid, self.apiroutine),
                       api(self.waitconnectionbysystemid, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.getallbridges, self.apiroutine),
                       api(self.getbridgeinfo, self.apiroutine),
                       api(self.waitbridgeinfo, self.apiroutine)
                       )
        self._synchronized = False

    def _release_datapath(self, vhost, dpid, connection):
        "Forget the (vhost, dpid) -> connection mapping, but only if it still points to connection"
        key = (vhost, dpid)
        # A newer connection with the same system-id may already own this datapath;
        # never clobber its entry (and never fail if the entry is already gone)
        if self.managed_conns.get(key) is connection:
            del self.managed_conns[key]

    def _release_endpoint(self, vhost, connection):
        "Remove connection from the endpoint index, dropping the endpoint when no connection is left"
        key = (vhost, _get_endpoint(connection))
        econns = self.endpoint_conns.get(key)
        if econns is not None:
            try:
                econns.remove(connection)
            except ValueError:
                pass
            if not econns:
                # Keep "no connection on this endpoint" represented as a missing key (None from the API)
                del self.endpoint_conns[key]

    def _update_bridge(self, connection, protocol, bridge_uuid, vhost):
        """
        Wait until the bridge with bridge_uuid has a datapath-id, then register it
        in managed_bridges/managed_conns and send OVSDBBridgeSetup(UP).
        Bridges excluded by self.bridgenames are ignored.
        """
        try:
            method, params = ovsdb.transact('Open_vSwitch',
                                            ovsdb.wait('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                       ["datapath_id"], [{"datapath_id": ovsdb.oset()}], False, 5000),
                                            ovsdb.select('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                         ["datapath_id", "name"]))
            for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                yield m
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            if r['rows']:
                r0 = r['rows'][0]
                name = r0['name']
                dpid = int(r0['datapath_id'], 16)
                if self.bridgenames is None or name in self.bridgenames:
                    bridges = self.managed_bridges.get(connection)
                    if bridges is None:
                        # The connection was unregistered while we were querying
                        return
                    bridges.append((vhost, dpid, name, bridge_uuid))
                    self.managed_conns[(vhost, dpid)] = connection
                    for m in self.apiroutine.waitForSend(OVSDBBridgeSetup(OVSDBBridgeSetup.UP,
                                                                          dpid,
                                                                          connection.ovsdb_systemid,
                                                                          name,
                                                                          connection,
                                                                          connection.connmark,
                                                                          vhost,
                                                                          bridge_uuid)):
                        yield m
        except JsonRPCProtocolException:
            pass

    def _get_bridges(self, connection, protocol):
        """
        Per-connection routine: acquire the system-id, register the connection,
        monitor the Bridge table and keep the bridge registry up to date until
        the connection goes down. OVSDBConnectionSetup is always sent once the
        initial state is processed (successfully or not).
        """
        # Bound before anything can fail: the failure handler reports it
        system_id = None
        try:
            vhost = protocol.vhost
            try:
                if not hasattr(connection, 'ovsdb_systemid'):
                    method, params = ovsdb.transact('Open_vSwitch', ovsdb.select('Open_vSwitch',
                                                                                 [], ['external_ids']))
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    result = self.apiroutine.jsonrpc_result[0]
                    if 'error' in result:
                        raise JsonRPCErrorResultException('Error while acquiring system-id: ' + repr(result['error']))
                    system_id = ovsdb.omap_getvalue(result['rows'][0]['external_ids'], 'system-id')
                    connection.ovsdb_systemid = system_id
                else:
                    system_id = connection.ovsdb_systemid
                systemid_key = (vhost, system_id)
                old_conn = self.managed_systemids.get(systemid_key)
                if old_conn is not None:
                    # An older connection reported the same system-id; this one replaces it
                    self._release_endpoint(vhost, old_conn)
                self.managed_systemids[systemid_key] = connection
                self.managed_bridges[connection] = []
                self.endpoint_conns.setdefault((vhost, _get_endpoint(connection)), []).append(connection)
                monitor_method, monitor_params = ovsdb.monitor('Open_vSwitch', 'ovsdb_manager_bridges_monitor',
                                                               {'Bridge': ovsdb.monitor_request(['name', 'datapath_id'])})
                for m in protocol.querywithreply(monitor_method, monitor_params, connection, self.apiroutine):
                    yield m
                if 'error' in self.apiroutine.jsonrpc_result:
                    # The monitor is already set, cancel it first
                    method, params = ovsdb.monitor_cancel('ovsdb_manager_bridges_monitor')
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine, False):
                        yield m
                    for m in protocol.querywithreply(monitor_method, monitor_params, connection, self.apiroutine):
                        yield m
                    if 'error' in self.apiroutine.jsonrpc_result:
                        raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result))
            except Exception:
                for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                    yield m
                raise
            else:
                # Process initial bridges; OVSDB omits tables without rows from
                # the monitor reply, so a switch with no bridge has no 'Bridge' key
                init_subprocesses = [self._update_bridge(connection, protocol, buuid, vhost)
                                     for buuid in self.apiroutine.jsonrpc_result.get('Bridge', {}).keys()]

                def init_process():
                    try:
                        with closing(self.apiroutine.executeAll(init_subprocesses, retnames = ())) as g:
                            for m in g:
                                yield m
                    except Exception:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                        raise
                    else:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                self.apiroutine.subroutine(init_process())
            # Wait for notify
            notification = JsonRPCNotificationEvent.createMatcher('update', connection, connection.connmark,
                                                                  _ismatch = lambda x: x.params[0] == 'ovsdb_manager_bridges_monitor')
            conn_down = protocol.statematcher(connection)
            while True:
                yield (conn_down, notification)
                if self.apiroutine.matcher is conn_down:
                    break
                for buuid, v in self.apiroutine.event.params[1].get('Bridge', {}).items():
                    # If a bridge's name or datapath-id is changed, we remove this bridge and add it again
                    if 'old' in v:
                        # A bridge is deleted (or is being re-added with new values)
                        bridges = self.managed_bridges.get(connection, [])
                        for i, (_, dpid, name, bridge_uuid) in enumerate(bridges):
                            if bridge_uuid == buuid:
                                new_values = v.get('new', {})
                                if 'datapath_id' in new_values:
                                    new_dpid = int(new_values['datapath_id'], 16)
                                else:
                                    new_dpid = None
                                self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                           dpid,
                                                                           system_id,
                                                                           name,
                                                                           connection,
                                                                           connection.connmark,
                                                                           vhost,
                                                                           bridge_uuid,
                                                                           new_datapath_id = new_dpid))
                                self._release_datapath(vhost, dpid, connection)
                                del bridges[i]
                                break
                    if 'new' in v:
                        # A bridge is added
                        self.apiroutine.subroutine(self._update_bridge(connection, protocol, buuid, vhost))
        except JsonRPCProtocolException:
            pass
        finally:
            del connection._ovsdb_manager_get_bridges

    def _manage_existing(self):
        "Start managing the connections that already exist, then mark the module as synchronized"
        for m in callAPI(self.apiroutine, "jsonrpcserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        # Only wait for connections we actually manage: connections filtered out
        # by vhostbind never send OVSDBConnectionSetup
        conns = [c for c in self.apiroutine.retvalue
                 if vb is None or c.protocol.vhost in vb]
        for c in conns:
            if not hasattr(c, '_ovsdb_manager_get_bridges'):
                c._ovsdb_manager_get_bridges = self.apiroutine.subroutine(self._get_bridges(c, c.protocol))
        matchers = [OVSDBConnectionSetup.createMatcher(None, c, c.connmark) for c in conns]
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Block until the initial connections have been processed"
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)

    def _manage_conns(self):
        """
        Main routine: start a _get_bridges routine for every new connection and
        unregister connections (sending OVSDBBridgeSetup DOWN for their bridges)
        when they go down or when the module stops.
        """
        try:
            self.apiroutine.subroutine(self._manage_existing())
            vb = self.vhostbind
            if vb is not None:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP,
                                                                    _ismatch = lambda x: x.createby.vhost in vb)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                      _ismatch = lambda x: x.createby.vhost in vb)
            else:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                e = self.apiroutine.event
                conn = e.connection
                if self.apiroutine.matcher is conn_up:
                    if not hasattr(conn, '_ovsdb_manager_get_bridges'):
                        conn._ovsdb_manager_get_bridges = self.apiroutine.subroutine(self._get_bridges(conn, e.createby))
                else:
                    bridges = self.managed_bridges.get(conn)
                    if bridges is not None:
                        vhost = e.createby.vhost
                        systemid_key = (vhost, conn.ovsdb_systemid)
                        # A newer connection with the same system-id may have replaced this one
                        if self.managed_systemids.get(systemid_key) is conn:
                            del self.managed_systemids[systemid_key]
                        del self.managed_bridges[conn]
                        for bridge_vhost, dpid, name, buuid in bridges:
                            self._release_datapath(bridge_vhost, dpid, conn)
                            self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                       dpid,
                                                                       conn.ovsdb_systemid,
                                                                       name,
                                                                       conn,
                                                                       conn.connmark,
                                                                       vhost,
                                                                       buuid))
                        self._release_endpoint(vhost, conn)
        finally:
            for c, bridges in list(self.managed_bridges.items()):
                if hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges.close()
                for vhost, dpid, name, buuid in bridges:
                    self._release_datapath(vhost, dpid, c)
                    self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                               dpid,
                                                               c.ovsdb_systemid,
                                                               name,
                                                               c,
                                                               c.connmark,
                                                               c.protocol.vhost,
                                                               buuid))

    def getconnection(self, datapathid, vhost = ''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_conns.get((vhost, datapathid))

    def waitconnection(self, datapathid, timeout = 30, vhost = ''):
        "Wait for a datapath connection"
        for m in self.getconnection(datapathid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                                                     OVSDBBridgeSetup.createMatcher(
                                                         state = OVSDBBridgeSetup.UP,
                                                         datapathid = datapathid, vhost = vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getdatapathids(self, vhost = ''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys() if k[0] == vhost]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost = ''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = list(self.managed_bridges.keys())
        else:
            self.apiroutine.retvalue = list(k for k in self.managed_bridges.keys() if k.protocol.vhost == vhost)

    def getbridges(self, connection):
        "Get all (dpid, name, _uuid) tuple on this connection"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            self.apiroutine.retvalue = [(dpid, name, buuid) for _, dpid, name, buuid in bridges]
        else:
            self.apiroutine.retvalue = None

    def getallbridges(self, vhost = None):
        "Get all (dpid, name, _uuid) tuple for all connections, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is not None:
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for c, bridges in self.managed_bridges.items()
                                        if c.protocol.vhost == vhost
                                        for _, dpid, name, buuid in bridges]
        else:
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for c, bridges in self.managed_bridges.items()
                                        for _, dpid, name, buuid in bridges]

    def getbridge(self, connection, name):
        "Get datapath ID on this connection with specified name"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, n, _ in bridges:
                if n == name:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None

    def waitbridge(self, connection, name, timeout = 30):
        "Wait for bridge with specified name appears and return the datapath-id"
        bnames = self.bridgenames
        if bnames is not None and name not in bnames:
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear: it is not in the selected bridge names')
        for m in self.getbridge(connection, name):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.UP,
                                                          None,
                                                          None,
                                                          name,
                                                          connection
                                                          )
            conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                  connection,
                                                                  connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException('Connection is down before bridge ' + repr(name) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getbridgebyuuid(self, connection, uuid):
        "Get datapath ID of bridge on this connection with specified _uuid"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, _, buuid in bridges:
                if buuid == uuid:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None

    def waitbridgebyuuid(self, connection, uuid, timeout = 30):
        "Wait for bridge with specified _uuid appears and return the datapath-id"
        for m in self.getbridgebyuuid(connection, uuid):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(state = OVSDBBridgeSetup.UP,
                                                          connection = connection,
                                                          bridgeuuid = uuid
                                                          )
            conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                  connection,
                                                                  connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge ' + repr(uuid) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException('Connection is down before bridge ' + repr(uuid) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getsystemids(self, vhost = ''):
        "Get All system-ids"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_systemids.keys() if k[0] == vhost]

    def getallsystemids(self):
        "Get all system-ids from any vhost. Return (vhost, system-id) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_systemids.keys())

    def getconnectionbysystemid(self, systemid, vhost = ''):
        "Get the connection with the specified system-id, or None"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_systemids.get((vhost, systemid))

    def waitconnectionbysystemid(self, systemid, timeout = 30, vhost = ''):
        "Wait for a connection with specified system-id"
        for m in self.getconnectionbysystemid(systemid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                                                     OVSDBConnectionSetup.createMatcher(
                                                         systemid, None, None, vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getconnectionsbyendpoint(self, endpoint, vhost = ''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost = '', timeout = 30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        if not name:
            for m in self.getconnectionsbyendpoint('', vhost):
                yield m
            return
        request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP,
                   socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
        # Resolve hostname; resolving is only allowed through the asynchronous resolver
        for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
            yield m
        for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
            yield m
        if self.apiroutine.timeout:
            raise IOError('Resolve hostname timeout: ' + name)
        if hasattr(self.apiroutine.event, 'error'):
            raise IOError('Cannot resolve hostname: ' + name)
        resp = self.apiroutine.event.response
        # Return None (not a stale value) when no resolved address has connections
        self.apiroutine.retvalue = None
        for r in resp:
            raddr = r[4]
            if isinstance(raddr, tuple):
                # Ignore port
                endpoint = raddr[0]
            else:
                # Unix socket? This should not happen, but in case...
                endpoint = raddr
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
            if self.apiroutine.retvalue is not None:
                break

    def getendpoints(self, vhost = ''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def getbridgeinfo(self, datapathid, vhost = ''):
        "Get (bridgename, systemid, bridge_uuid) tuple from bridge datapathid"
        for m in self.getconnection(datapathid, vhost):
            yield m
        c = self.apiroutine.retvalue
        # Not found (unknown datapath, or the connection has no such bridge) -> None
        self.apiroutine.retvalue = None
        if c is not None:
            for _, dpid, n, buuid in self.managed_bridges.get(c, ()):
                if dpid == datapathid:
                    self.apiroutine.retvalue = (n, c.ovsdb_systemid, buuid)
                    return

    def waitbridgeinfo(self, datapathid, timeout = 30, vhost = ''):
        "Wait for bridge with datapathid, and return (bridgename, systemid, bridge_uuid) tuple"
        for m in self.getbridgeinfo(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                                                     OVSDBBridgeSetup.createMatcher(
                                                         OVSDBBridgeSetup.UP, datapathid,
                                                         None, None, None, None,
                                                         vhost)):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge 0x%016x does not appear before timeout' % (datapathid,))
            e = self.apiroutine.event
            self.apiroutine.retvalue = (e.name, e.systemid, e.bridgeuuid)
class OVSDBPortManager(Module): ''' Manage Ports from OVSDB Protocol ''' service = True def __init__(self, server): Module.__init__(self, server) self.apiroutine = RoutineContainer(self.scheduler) self.apiroutine.main = self._manage_ports self.routines.append(self.apiroutine) self.managed_ports = {} self.managed_ids = {} self.monitor_routines = set() self.ports_uuids = {} self.wait_portnos = {} self.wait_names = {} self.wait_ids = {} self.bridge_datapathid = {} self.createAPI(api(self.getports, self.apiroutine), api(self.getallports, self.apiroutine), api(self.getportbyid, self.apiroutine), api(self.waitportbyid, self.apiroutine), api(self.getportbyname, self.apiroutine), api(self.waitportbyname, self.apiroutine), api(self.getportbyno, self.apiroutine), api(self.waitportbyno, self.apiroutine), api(self.resync, self.apiroutine) ) self._synchronized = False def _get_interface_info(self, connection, protocol, buuid, interface_uuid, port_uuid): try: method, params = ovsdb.transact('Open_vSwitch', ovsdb.wait('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]], ["ofport"], [{"ofport":ovsdb.oset()}], False, 5000), ovsdb.wait('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]], ["ifindex"], [{"ifindex":ovsdb.oset()}], False, 5000), ovsdb.select('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]], ["_uuid", "name", "ifindex", "ofport", "type", "external_ids"])) for m in protocol.querywithreply(method, params, connection, self.apiroutine): yield m r = self.apiroutine.jsonrpc_result[0] if 'error' in r: raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error'])) r = self.apiroutine.jsonrpc_result[1] if 'error' in r: raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error'])) r = self.apiroutine.jsonrpc_result[2] if 'error' in r: raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error'])) if not r['rows']: self.apiroutine.retvalue = [] return r0 = r['rows'][0] if 
r0['ofport'] < 0: # Ignore this port because it is in an error state self.apiroutine.retvalue = [] return r0['_uuid'] = r0['_uuid'][1] r0['ifindex'] = ovsdb.getoptional(r0['ifindex']) r0['external_ids'] = ovsdb.getdict(r0['external_ids']) if buuid not in self.bridge_datapathid: self.apiroutine.retvalue = [] return else: datapath_id = self.bridge_datapathid[buuid] if 'iface-id' in r0['external_ids']: eid = r0['external_ids']['iface-id'] r0['id'] = eid id_ports = self.managed_ids.setdefault((protocol.vhost, eid), []) id_ports.append((datapath_id, r0)) else: r0['id'] = None self.managed_ports.setdefault((protocol.vhost, datapath_id),[]).append((port_uuid, r0)) notify = False if (protocol.vhost, datapath_id, r0['ofport']) in self.wait_portnos: notify = True del self.wait_portnos[(protocol.vhost, datapath_id, r0['ofport'])] if (protocol.vhost, datapath_id, r0['name']) in self.wait_names: notify = True del self.wait_names[(protocol.vhost, datapath_id, r0['name'])] if (protocol.vhost, r0['id']) in self.wait_ids: notify = True del self.wait_ids[(protocol.vhost, r0['id'])] if notify: for m in self.apiroutine.waitForSend(OVSDBPortUpNotification(connection, r0['name'], r0['ofport'], r0['id'], protocol.vhost, datapath_id, port = r0)): yield m self.apiroutine.retvalue = [r0] except JsonRPCProtocolException: self.apiroutine.retvalue = [] def _remove_interface_id(self, connection, protocol, datapath_id, port): eid = port['id'] eid_list = self.managed_ids.get((protocol.vhost, eid)) for i in range(0, len(eid_list)): if eid_list[i][1]['_uuid'] == port['_uuid']: del eid_list[i] break def _remove_interface(self, connection, protocol, datapath_id, interface_uuid, port_uuid): ports = self.managed_ports.get((protocol.vhost, datapath_id)) r = None if ports is not None: for i in range(0, len(ports)): if ports[i][1]['_uuid'] == interface_uuid: r = ports[i][1] if r['id']: self._remove_interface_id(connection, protocol, datapath_id, r) del ports[i] break if not ports: del 
self.managed_ports[(protocol.vhost, datapath_id)] return r def _remove_all_interface(self, connection, protocol, datapath_id, port_uuid, buuid): ports = self.managed_ports.get((protocol.vhost, datapath_id)) if ports is not None: removed_ports = [r for puuid, r in ports if puuid == port_uuid] not_removed_ports = [(puuid, r) for puuid, r in ports if puuid != port_uuid] ports[:len(not_removed_ports)] = not_removed_ports del ports[len(not_removed_ports):] for r in removed_ports: if r['id']: self._remove_interface_id(connection, protocol, datapath_id, r) if not ports: del self.managed_ports[(protocol.vhost, datapath_id)] return removed_ports if port_uuid in self.ports_uuids and self.ports_uuids[port_uuid] == buuid: del self.ports_uuids[port_uuid] return [] def _update_interfaces(self, connection, protocol, updateinfo, update = True): """ There are several kinds of updates, they may appear together: 1. New bridge created (or from initial updateinfo). We should add all the interfaces to the list. 2. Bridge removed. Remove all the ports. 3. name and datapath_id may be changed. We will consider this as a new bridge created, and an old bridge removed. 4. Bridge ports modification, i.e. add/remove ports. a) Normally a port record is created/deleted together. A port record cannot exist without a bridge containing it. b) It is also possible that a port is removed from one bridge and added to another bridge, in this case the ports do not appear in the updateinfo 5. Port interfaces modification, i.e. add/remove interfaces. The bridge record may not appear in this situation. We must consider these situations carefully and process them in correct order. 
""" port_update = updateinfo.get('Port', {}) bridge_update = updateinfo.get('Bridge', {}) working_routines = [] def process_bridge(buuid, uo): try: nv = uo['new'] if 'datapath_id' in nv: datapath_id = int(nv['datapath_id'], 16) self.bridge_datapathid[buuid] = datapath_id elif buuid in self.bridge_datapathid: datapath_id = self.bridge_datapathid[buuid] else: # This should not happen, but just in case... for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitbridge', {'connection': connection, 'name': nv['name'], 'timeout': 5}): yield m datapath_id = self.apiroutine.retvalue if 'ports' in nv: nset = set((p for _,p in ovsdb.getlist(nv['ports']))) else: nset = set() if 'old' in uo: ov = uo['old'] if 'ports' in ov: oset = set((p for _,p in ovsdb.getlist(ov['ports']))) else: # new ports are not really added; it is only sent because datapath_id is modified nset = set() oset = set() if 'datapath_id' in ov: old_datapathid = int(ov['datapath_id'], 16) else: old_datapathid = datapath_id else: oset = set() old_datapathid = datapath_id # For every deleted port, remove the interfaces with this port _uuid remove = [] add_routine = [] for puuid in oset - nset: remove += self._remove_all_interface(connection, protocol, old_datapathid, puuid, buuid) # For every port not changed, check if the interfaces are modified; for puuid in oset.intersection(nset): if puuid in port_update: # The port is modified, there should be an 'old' set and 'new' set pu = port_update[puuid] if 'old' in pu: poset = set((p for _,p in ovsdb.getlist(pu['old']['interfaces']))) else: poset = set() if 'new' in pu: pnset = set((p for _,p in ovsdb.getlist(pu['new']['interfaces']))) else: pnset = set() # Remove old interfaces remove += [r for r in (self._remove_interface(connection, protocol, datapath_id, iuuid, puuid) for iuuid in (poset - pnset)) if r is not None] # Prepare to add new interfaces add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid) for iuuid in (pnset - poset)] # For 
every port added, add the interfaces def add_port_interfaces(puuid): # If the uuid does not appear in update info, we have no choice but to query interfaces with select # we cannot use data from other bridges; the port may be moved from a bridge which is not tracked try: method, params = ovsdb.transact('Open_vSwitch', ovsdb.select('Port', [["_uuid", "==", ovsdb.uuid(puuid)]], ["interfaces"])) for m in protocol.querywithreply(method, params, connection, self.apiroutine): yield m r = self.apiroutine.jsonrpc_result[0] if 'error' in r: raise JsonRPCErrorResultException('Error when query interfaces from port ' + repr(puuid) + ': ' + r['error']) if r['rows']: interfaces = ovsdb.getlist(r['rows'][0]['interfaces']) with closing(self.apiroutine.executeAll([self._get_interface_info(connection, protocol, buuid, iuuid, puuid) for _,iuuid in interfaces])) as g: for m in g: yield m self.apiroutine.retvalue = list(itertools.chain(r[0] for r in self.apiroutine.retvalue)) else: self.apiroutine.retvalue = [] except JsonRPCProtocolException: self.apiroutine.retvalue = [] except ConnectionResetException: self.apiroutine.retvalue = [] for puuid in nset - oset: self.ports_uuids[puuid] = buuid if puuid in port_update and 'new' in port_update[puuid] \ and 'old' not in port_update[puuid]: # Add all the interfaces in 'new' interfaces = ovsdb.getlist(port_update[puuid]['new']['interfaces']) add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid) for _,iuuid in interfaces] else: add_routine.append(add_port_interfaces(puuid)) # Execute the add_routine try: with closing(self.apiroutine.executeAll(add_routine)) as g: for m in g: yield m except: add = [] raise else: add = list(itertools.chain(r[0] for r in self.apiroutine.retvalue)) finally: if update: self.scheduler.emergesend( ModuleNotification(self.getServiceName(), 'update', datapathid = datapath_id, connection = connection, vhost = protocol.vhost, add = add, remove = remove, reason = 'bridgemodify' if 'old' in uo 
else 'bridgeup' )) except JsonRPCProtocolException: pass except ConnectionResetException: pass except OVSDBBridgeNotAppearException: pass ignore_ports = set() for buuid, uo in bridge_update.items(): # Bridge removals are ignored because we process OVSDBBridgeSetup event instead if 'old' in uo: if 'ports' in uo['old']: oset = set((puuid for _, puuid in ovsdb.getlist(uo['old']['ports']))) ignore_ports.update(oset) if 'new' not in uo: if buuid in self.bridge_datapathid: del self.bridge_datapathid[buuid] if 'new' in uo: # If bridge contains this port is updated, we process the port update totally in bridge, # so we ignore it later if 'ports' in uo['new']: nset = set((puuid for _, puuid in ovsdb.getlist(uo['new']['ports']))) ignore_ports.update(nset) working_routines.append(process_bridge(buuid, uo)) def process_port(buuid, port_uuid, interfaces, remove_ids): if buuid not in self.bridge_datapathid: return datapath_id = self.bridge_datapathid[buuid] ports = self.managed_ports.get((protocol.vhost, datapath_id)) remove = [] if ports is not None: remove = [p for _,p in ports if p['_uuid'] in remove_ids] not_remove = [(_,p) for _,p in ports if p['_uuid'] not in remove_ids] ports[:len(not_remove)] = not_remove del ports[len(not_remove):] if interfaces: try: with closing(self.apiroutine.executeAll([self._get_interface_info(connection, protocol, buuid, iuuid, port_uuid) for iuuid in interfaces])) as g: for m in g: yield m add = list(itertools.chain((r[0] for r in self.apiroutine.retvalue if r[0]))) except Exception: self._logger.warning("Cannot get new port information", exc_info = True) add = [] else: add = [] if update: for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'update', datapathid = datapath_id, connection = connection, vhost = protocol.vhost, add = add, remove = remove, reason = 'bridgemodify' )): yield m for puuid, po in port_update.items(): if puuid not in ignore_ports: bridge_id = self.ports_uuids.get(puuid) if bridge_id is not None: 
datapath_id = self.bridge_datapathid[bridge_id] if datapath_id is not None: # This port is modified if 'new' in po: nset = set((iuuid for _, iuuid in ovsdb.getlist(po['new']['interfaces']))) else: nset = set() if 'old' in po: oset = set((iuuid for _, iuuid in ovsdb.getlist(po['old']['interfaces']))) else: oset = set() working_routines.append(process_port(bridge_id, puuid, nset - oset, oset - nset)) if update: for r in working_routines: self.apiroutine.subroutine(r) else: try: with closing(self.apiroutine.executeAll(working_routines, None, ())) as g: for m in g: yield m finally: self.scheduler.emergesend(OVSDBConnectionPortsSynchronized(connection)) def _get_ports(self, connection, protocol): try: try: method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_port_manager_interfaces_monitor', { 'Bridge':[ovsdb.monitor_request(["name", "datapath_id", "ports"])], 'Port':[ovsdb.monitor_request(["interfaces"])] }) for m in protocol.querywithreply(method, params, connection, self.apiroutine): yield m if 'error' in self.apiroutine.jsonrpc_result: # The monitor is already set, cancel it first method2, params2 = ovsdb.monitor_cancel('ovsdb_port_manager_interfaces_monitor') for m in protocol.querywithreply(method2, params2, connection, self.apiroutine, False): yield m for m in protocol.querywithreply(method, params, connection, self.apiroutine): yield m if 'error' in self.apiroutine.jsonrpc_result: raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result)) r = self.apiroutine.jsonrpc_result except: for m in self.apiroutine.waitForSend(OVSDBConnectionPortsSynchronized(connection)): yield m raise # This is the initial state, it should contains all the ids of ports and interfaces self.apiroutine.subroutine(self._update_interfaces(connection, protocol, r, False)) update_matcher = JsonRPCNotificationEvent.createMatcher('update', connection, connection.connmark, _ismatch = lambda x: x.params[0] == 'ovsdb_port_manager_interfaces_monitor') 
conn_state = protocol.statematcher(connection) while True: yield (update_matcher, conn_state) if self.apiroutine.matcher is conn_state: break else: self.apiroutine.subroutine(self._update_interfaces(connection, protocol, self.apiroutine.event.params[1], True)) except JsonRPCProtocolException: pass finally: if self.apiroutine.currentroutine in self.monitor_routines: self.monitor_routines.remove(self.apiroutine.currentroutine) def _get_existing_ports(self): for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getallconnections', {'vhost':None}): yield m matchers = [] for c in self.apiroutine.retvalue: self.monitor_routines.add(self.apiroutine.subroutine(self._get_ports(c, c.protocol))) matchers.append(OVSDBConnectionPortsSynchronized.createMatcher(c)) for m in self.apiroutine.waitForAll(*matchers): yield m self._synchronized = True for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')): yield m def _wait_for_sync(self): if not self._synchronized: yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),) def _manage_ports(self): try: self.apiroutine.subroutine(self._get_existing_ports()) connsetup = OVSDBConnectionSetup.createMatcher() bridgedown = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.DOWN) while True: yield (connsetup, bridgedown) e = self.apiroutine.event if self.apiroutine.matcher is connsetup: self.monitor_routines.add(self.apiroutine.subroutine(self._get_ports(e.connection, e.connection.protocol))) else: # Remove ports of the bridge ports = self.managed_ports.get((e.vhost, e.datapathid)) if ports is not None: ports_original = ports ports = [p for _,p in ports] for p in ports: if p['id']: self._remove_interface_id(e.connection, e.connection.protocol, e.datapathid, p) newdpid = getattr(e, 'new_datapath_id', None) buuid = e.bridgeuuid if newdpid is not None: # This bridge changes its datapath id if buuid in self.bridge_datapathid and self.bridge_datapathid[buuid] == e.datapathid: 
self.bridge_datapathid[buuid] = newdpid def re_add_interfaces(): with closing(self.apiroutine.executeAll( [self._get_interface_info(e.connection, e.connection.protocol, buuid, r['_uuid'], puuid) for puuid, r in ports_original])) as g: for m in g: yield m add = list(itertools.chain(r[0] for r in self.apiroutine.retvalue)) for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'update', datapathid = e.datapathid, connection = e.connection, vhost = e.vhost, add = add, remove = [], reason = 'bridgeup' )): yield m self.apiroutine.subroutine(re_add_interfaces()) else: # The ports are removed for puuid, _ in ports_original: if puuid in self.ports_uuids[puuid] and self.ports_uuids[puuid] == buuid: del self.ports_uuids[puuid] del self.managed_ports[(e.vhost, e.datapathid)] self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update', datapathid = e.datapathid, connection = e.connection, vhost = e.vhost, add = [], remove = ports, reason = 'bridgedown' )) finally: for r in list(self.monitor_routines): r.close() self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized')) def getports(self, datapathid, vhost = ''): "Return all ports of a specifed datapath" for m in self._wait_for_sync(): yield m self.apiroutine.retvalue = [p for _,p in self.managed_ports.get((vhost, datapathid), [])] def getallports(self, vhost = None): "Return all (datapathid, port, vhost) tuples, optionally filterd by vhost" for m in self._wait_for_sync(): yield m if vhost is None: self.apiroutine.retvalue = [(dpid, p, vh) for (vh, dpid),v in self.managed_ports.items() for _,p in v] else: self.apiroutine.retvalue = [(dpid, p, vh) for (vh, dpid),v in self.managed_ports.items() if vh == vhost for _,p in v] def getportbyno(self, datapathid, portno, vhost = ''): "Return port with specified portno" portno &= 0xffff for m in self._wait_for_sync(): yield m self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost) def _getportbyno(self, 
datapathid, portno, vhost = ''): ports = self.managed_ports.get((vhost, datapathid)) if ports is None: return None else: for _, p in ports: if p['ofport'] == portno: return p return None def waitportbyno(self, datapathid, portno, timeout = 30, vhost = ''): "Wait for port with specified portno" portno &= 0xffff for m in self._wait_for_sync(): yield m def waitinner(): p = self._getportbyno(datapathid, portno, vhost) if p is not None: self.apiroutine.retvalue = p else: try: self.wait_portnos[(vhost, datapathid, portno)] = \ self.wait_portnos.get((vhost, datapathid, portno),0) + 1 yield (OVSDBPortUpNotification.createMatcher(None, None, portno, None, vhost, datapathid),) except: v = self.wait_portnos.get((vhost, datapathid, portno)) if v is not None: if v <= 1: del self.wait_portnos[(vhost, datapathid, portno)] else: self.wait_portnos[(vhost, datapathid, portno)] = v - 1 raise else: self.apiroutine.retvalue = self.apiroutine.event.port for m in self.apiroutine.executeWithTimeout(timeout, waitinner()): yield m if self.apiroutine.timeout: raise OVSDBPortNotAppearException('Port ' + repr(portno) + ' does not appear before timeout') def getportbyname(self, datapathid, name, vhost = ''): "Return port with specified name" if isinstance(name, bytes): name = name.decode('utf-8') for m in self._wait_for_sync(): yield m self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost) def _getportbyname(self, datapathid, name, vhost = ''): ports = self.managed_ports.get((vhost, datapathid)) if ports is None: return None else: for _, p in ports: if p['name'] == name: return p return None def waitportbyname(self, datapathid, name, timeout = 30, vhost = ''): "Wait for port with specified name" for m in self._wait_for_sync(): yield m def waitinner(): p = self._getportbyname(datapathid, name, vhost) if p is not None: self.apiroutine.retvalue = p else: try: self.wait_names[(vhost, datapathid, name)] = \ self.wait_portnos.get((vhost, datapathid, name) ,0) + 1 yield 
(OVSDBPortUpNotification.createMatcher(None, name, None, None, vhost, datapathid),) except: v = self.wait_names.get((vhost, datapathid, name)) if v is not None: if v <= 1: del self.wait_names[(vhost, datapathid, name)] else: self.wait_names[(vhost, datapathid, name)] = v - 1 raise else: self.apiroutine.retvalue = self.apiroutine.event.port for m in self.apiroutine.executeWithTimeout(timeout, waitinner()): yield m if self.apiroutine.timeout: raise OVSDBPortNotAppearException('Port ' + repr(name) + ' does not appear before timeout') def getportbyid(self, id, vhost = ''): "Return port with the specified id. The return value is a pair: (datapath_id, port)" for m in self._wait_for_sync(): yield m self.apiroutine = self._getportbyid(id, vhost) def _getportbyid(self, id, vhost = ''): ports = self.managed_ids.get((vhost, id)) if ports: return ports[0] else: return None def waitportbyid(self, id, timeout = 30, vhost = ''): "Wait for port with the specified id. The return value is a pair (datapath_id, port)" for m in self._wait_for_sync(): yield m def waitinner(): p = self._getportbyid(id, vhost) if p is None: try: self.wait_ids[(vhost, id)] = self.wait_ids.get((vhost, id), 0) + 1 yield (OVSDBPortUpNotification.createMatcher(None, None, None, id, vhost),) except: v = self.wait_ids.get((vhost, id)) if v is not None: if v <= 1: del self.wait_ids[(vhost, id)] else: self.wait_ids[(vhost, id)] = v - 1 raise else: self.apiroutine.retvalue = (self.apiroutine.event.datapathid, self.apiroutine.event.port) else: self.apiroutine.retvalue = p for m in self.apiroutine.executeWithTimeout(timeout, waitinner()): yield m if self.apiroutine.timeout: raise OVSDBPortNotAppearException('Port ' + repr(id) + ' does not appear before timeout') def resync(self, datapathid, vhost = ''): ''' Resync with current ports ''' # Sometimes when the OVSDB connection is very busy, monitor message may be dropped. 
# We must deal with this and recover from it # Save current manged_ports if (vhost, datapathid) not in self.managed_ports: self.apiroutine.retvalue = None return else: for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getconnection', {'datapathid': datapathid, 'vhost':vhost}): yield m c = self.apiroutine.retvalue if c is not None: # For now, we restart the connection... for m in c.reconnect(False): yield m for m in self.apiroutine.waitWithTimeout(0.1): yield m for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitconnection', {'datapathid': datapathid, 'vhost': vhost}): yield m self.apiroutine.retvalue = None