class TestModule(Module):
    '''
    Benchmark module: runs batches of ZooKeeper create/delete transactions
    through a pool of client connections and prints the elapsed time.
    '''
    # Default server URLs; read back below as ``self.serverlist``
    # (presumably resolved by the Module configuration machinery - confirm).
    _default_serverlist = [
        'tcp://localhost:3181/',
        'tcp://localhost:3182/',
        'tcp://localhost:3183/',
    ]

    def __init__(self, server):
        Module.__init__(self, server)
        # One routine container drives the whole benchmark
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self.main
        self.routines.append(self.apiroutine)

    def main(self):
        '''
        Start 10 clients, run 100 concurrent workers (100 rounds each),
        print the timing and shut every client down.
        '''
        clients = [ZooKeeperClient(self.apiroutine, self.serverlist)
                   for _ in range(10)]
        for started in clients:
            started.start()

        def test_loop(number):
            # Every worker owns a node named after its index and shares a
            # client with the other workers mapped onto the same slot.
            root = ('vlcptest_' + str(number)).encode('utf-8')
            child = root + b'/subtest'
            conn = clients[number % len(clients)]
            for _ in range(100):
                # Round part 1: create parent + child atomically, then list
                create_txn = zk.multi(
                    zk.multi_create(root, b'test'),
                    zk.multi_create(child, 'test2')
                )
                list_after_create = zk.getchildren2(root, True)
                for m in conn.requests([create_txn, list_after_create],
                                       self.apiroutine):
                    yield m
                # Round part 2: delete child + parent atomically, then list
                delete_txn = zk.multi(
                    zk.multi_delete(child),
                    zk.multi_delete(root)
                )
                list_after_delete = zk.getchildren2(root, True)
                for m in conn.requests([delete_txn, list_after_delete],
                                       self.apiroutine):
                    yield m

        from time import time
        starttime = time()
        workers = [test_loop(index) for index in range(100)]
        for m in self.apiroutine.executeAll(workers):
            yield m
        print('10000 loops in %r seconds, with %d connections' % (time() - starttime, len(clients)))
        for finished in clients:
            for m in finished.shutdown():
                yield m
class OVSDBManager(Module):
    '''
    Manage OVSDB (JSON-RPC) connections and the bridges they expose
    '''
    service = True
    # Bind to JsonRPCServer vHosts. If not None, should be a list of vHost names e.g. ``['']``
    _default_vhostbind = None
    # Only acquire information from bridges with these names
    _default_bridgenames = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapath-id) -> connection
        self.managed_conns = {}
        # (vhost, system-id) -> connection
        self.managed_systemids = {}
        # connection -> list of (vhost, datapath-id, name, bridge uuid)
        self.managed_bridges = {}
        self.managed_routines = []
        # (vhost, endpoint) -> list of connections
        self.endpoint_conns = {}
        self.createAPI(api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getbridges, self.apiroutine),
                       api(self.getbridge, self.apiroutine),
                       api(self.getbridgebyuuid, self.apiroutine),
                       api(self.waitbridge, self.apiroutine),
                       api(self.waitbridgebyuuid, self.apiroutine),
                       api(self.getsystemids, self.apiroutine),
                       api(self.getallsystemids, self.apiroutine),
                       api(self.getconnectionbysystemid, self.apiroutine),
                       api(self.waitconnectionbysystemid, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.getallbridges, self.apiroutine),
                       api(self.getbridgeinfo, self.apiroutine),
                       api(self.waitbridgeinfo, self.apiroutine))
        self._synchronized = False

    def _update_bridge(self, connection, protocol, bridge_uuid, vhost):
        '''
        Query name and datapath-id of one bridge, register it and send
        ``OVSDBBridgeSetup.UP``. JSON-RPC protocol errors are swallowed.
        '''
        try:
            # The ``wait`` operation blocks (up to 5000 ms) until datapath_id
            # is no longer the empty set; the ``select`` then reads the row.
            method, params = ovsdb.transact(
                'Open_vSwitch',
                ovsdb.wait('Bridge',
                           [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                           ["datapath_id"],
                           [{"datapath_id": ovsdb.oset()}],
                           False, 5000),
                ovsdb.select('Bridge',
                             [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                             ["datapath_id", "name"]))
            for m in protocol.querywithreply(method, params, connection,
                                             self.apiroutine):
                yield m
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring datapath-id: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring datapath-id: ' + repr(r['error']))
            if r['rows']:
                r0 = r['rows'][0]
                name = r0['name']
                dpid = int(r0['datapath_id'], 16)
                if self.bridgenames is None or name in self.bridgenames:
                    self.managed_bridges[connection].append(
                        (vhost, dpid, name, bridge_uuid))
                    self.managed_conns[(vhost, dpid)] = connection
                    for m in self.apiroutine.waitForSend(
                            OVSDBBridgeSetup(OVSDBBridgeSetup.UP,
                                             dpid,
                                             connection.ovsdb_systemid,
                                             name,
                                             connection,
                                             connection.connmark,
                                             vhost,
                                             bridge_uuid)):
                        yield m
        except JsonRPCProtocolException:
            pass

    def _get_bridges(self, connection, protocol):
        '''
        Per-connection routine: acquire the system-id, register the
        connection, install the bridge monitor, process the initial bridges
        and then follow bridge updates until the connection goes down.
        '''
        try:
            # Bound up-front so the error handler below can always build the
            # OVSDBConnectionSetup event, even when acquiring the system-id
            # itself is what failed.
            system_id = None
            try:
                vhost = protocol.vhost
                if not hasattr(connection, 'ovsdb_systemid'):
                    method, params = ovsdb.transact(
                        'Open_vSwitch',
                        ovsdb.select('Open_vSwitch', [], ['external_ids']))
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                    result = self.apiroutine.jsonrpc_result[0]
                    system_id = ovsdb.omap_getvalue(
                        result['rows'][0]['external_ids'], 'system-id')
                    connection.ovsdb_systemid = system_id
                else:
                    system_id = connection.ovsdb_systemid
                if (vhost, system_id) in self.managed_systemids:
                    # A previous connection used the same system-id:
                    # drop it from the endpoint index before replacing it
                    oc = self.managed_systemids[(vhost, system_id)]
                    ep = _get_endpoint(oc)
                    econns = self.endpoint_conns.get((vhost, ep))
                    if econns:
                        try:
                            econns.remove(oc)
                        except ValueError:
                            pass
                    del self.managed_systemids[(vhost, system_id)]
                self.managed_systemids[(vhost, system_id)] = connection
                self.managed_bridges[connection] = []
                ep = _get_endpoint(connection)
                self.endpoint_conns.setdefault((vhost, ep), []).append(connection)
                method, params = ovsdb.monitor(
                    'Open_vSwitch',
                    'ovsdb_manager_bridges_monitor',
                    {'Bridge': ovsdb.monitor_request(['name', 'datapath_id'])})
                try:
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                except JsonRPCErrorResultException:
                    # The monitor is already set, cancel it first
                    method, params = ovsdb.monitor_cancel(
                        'ovsdb_manager_bridges_monitor')
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine, False):
                        yield m
                    method, params = ovsdb.monitor(
                        'Open_vSwitch',
                        'ovsdb_manager_bridges_monitor',
                        {'Bridge': ovsdb.monitor_request(['name', 'datapath_id'])})
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
            except Exception:
                # Still announce the connection so that routines waiting for
                # OVSDBConnectionSetup are released, then propagate
                for m in self.apiroutine.waitForSend(
                        OVSDBConnectionSetup(system_id, connection,
                                             connection.connmark, vhost)):
                    yield m
                raise
            else:
                # Process initial bridges
                init_subprocesses = []
                if self.apiroutine.jsonrpc_result and 'Bridge' in self.apiroutine.jsonrpc_result:
                    init_subprocesses = [
                        self._update_bridge(connection, protocol, buuid, vhost)
                        for buuid in self.apiroutine.jsonrpc_result['Bridge'].keys()
                    ]

                def init_process():
                    # The setup event is sent whether or not the initial
                    # bridge queries succeed
                    try:
                        with closing(self.apiroutine.executeAll(
                                init_subprocesses, retnames=())) as g:
                            for m in g:
                                yield m
                    except Exception:
                        for m in self.apiroutine.waitForSend(
                                OVSDBConnectionSetup(system_id, connection,
                                                     connection.connmark, vhost)):
                            yield m
                        raise
                    else:
                        for m in self.apiroutine.waitForSend(
                                OVSDBConnectionSetup(system_id, connection,
                                                     connection.connmark, vhost)):
                            yield m
                self.apiroutine.subroutine(init_process())
            # Wait for notify
            notification = JsonRPCNotificationEvent.createMatcher(
                'update', connection, connection.connmark,
                _ismatch=lambda x: x.params[0] == 'ovsdb_manager_bridges_monitor')
            conn_down = protocol.statematcher(connection)
            while True:
                yield (conn_down, notification)
                if self.apiroutine.matcher is conn_down:
                    break
                for buuid, v in self.apiroutine.event.params[1]['Bridge'].items():
                    # If a bridge's name or datapath-id is changed, we remove this bridge and add it again
                    if 'old' in v:
                        # A bridge is deleted
                        bridges = self.managed_bridges[connection]
                        for i, (_, old_dpid, old_name, old_uuid) in enumerate(bridges):
                            if buuid == old_uuid:
                                if 'new' in v and 'datapath_id' in v['new']:
                                    new_dpid = int(v['new']['datapath_id'], 16)
                                else:
                                    new_dpid = None
                                self.scheduler.emergesend(
                                    OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                     old_dpid,
                                                     system_id,
                                                     old_name,
                                                     connection,
                                                     connection.connmark,
                                                     vhost,
                                                     old_uuid,
                                                     new_datapath_id=new_dpid))
                                del self.managed_conns[(vhost, old_dpid)]
                                # Safe while iterating: we leave the loop at once
                                del bridges[i]
                                break
                    if 'new' in v:
                        # A bridge is added
                        self.apiroutine.subroutine(
                            self._update_bridge(connection, protocol, buuid, vhost))
        except JsonRPCProtocolException:
            pass
        finally:
            del connection._ovsdb_manager_get_bridges

    def _manage_existing(self):
        '''
        Start managing connections that already exist, wait until each has
        sent ``OVSDBConnectionSetup`` and then mark the module synchronized.
        '''
        for m in callAPI(self.apiroutine, "jsonrpcserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        conns = self.apiroutine.retvalue
        for c in conns:
            if vb is None or c.protocol.vhost in vb:
                if not hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges = self.apiroutine.subroutine(
                        self._get_bridges(c, c.protocol))
        matchers = [OVSDBConnectionSetup.createMatcher(None, c, c.connmark)
                    for c in conns
                    if vb is None or c.protocol.vhost in vb]
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Block until the initial synchronization has completed"
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'),)

    def _manage_conns(self):
        '''
        Main routine: follow JSON-RPC connection up/down events and keep the
        connection, system-id, bridge and endpoint indexes up to date.
        '''
        try:
            self.apiroutine.subroutine(self._manage_existing())
            vb = self.vhostbind
            if vb is not None:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_UP,
                    _ismatch=lambda x: x.createby.vhost in vb)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                    _ismatch=lambda x: x.createby.vhost in vb)
            else:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_UP)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    if not hasattr(self.apiroutine.event.connection,
                                   '_ovsdb_manager_get_bridges'):
                        self.apiroutine.event.connection._ovsdb_manager_get_bridges = \
                            self.apiroutine.subroutine(
                                self._get_bridges(self.apiroutine.event.connection,
                                                  self.apiroutine.event.createby))
                else:
                    e = self.apiroutine.event
                    conn = e.connection
                    bridges = self.managed_bridges.get(conn)
                    if bridges is not None:
                        del self.managed_systemids[(e.createby.vhost, conn.ovsdb_systemid)]
                        del self.managed_bridges[conn]
                        for vhost, dpid, name, buuid in bridges:
                            del self.managed_conns[(vhost, dpid)]
                            self.scheduler.emergesend(
                                OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                 dpid,
                                                 conn.ovsdb_systemid,
                                                 name,
                                                 conn,
                                                 conn.connmark,
                                                 e.createby.vhost,
                                                 buuid))
                        # endpoint_conns is keyed by (vhost, endpoint), the
                        # same key _get_bridges registered the connection under
                        econns = self.endpoint_conns.get(
                            (e.createby.vhost, _get_endpoint(conn)))
                        if econns is not None:
                            try:
                                econns.remove(conn)
                            except ValueError:
                                pass
        finally:
            # Module is stopping: close per-connection routines and report
            # every known bridge as down
            for c in self.managed_bridges.keys():
                if hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges.close()
                bridges = self.managed_bridges.get(c)
                if bridges is not None:
                    for vhost, dpid, name, buuid in bridges:
                        del self.managed_conns[(vhost, dpid)]
                        self.scheduler.emergesend(
                            OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                             dpid,
                                             c.ovsdb_systemid,
                                             name,
                                             c,
                                             c.connmark,
                                             c.protocol.vhost,
                                             buuid))

    def getconnection(self, datapathid, vhost=''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_conns.get((vhost, datapathid))

    def waitconnection(self, datapathid, timeout=30, vhost=''):
        "Wait for a datapath connection"
        for m in self.getconnection(datapathid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(
                    timeout,
                    OVSDBBridgeSetup.createMatcher(state=OVSDBBridgeSetup.UP,
                                                   datapathid=datapathid,
                                                   vhost=vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getdatapathids(self, vhost=''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys()
                                    if k[0] == vhost]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return ``(vhost, datapathid)`` pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost=''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = list(self.managed_bridges.keys())
        else:
            self.apiroutine.retvalue = [k for k in self.managed_bridges.keys()
                                        if k.protocol.vhost == vhost]

    def getbridges(self, connection):
        "Get all ``(dpid, name, _uuid)`` tuple on this connection"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for _, dpid, name, buuid in bridges]
        else:
            self.apiroutine.retvalue = None

    def getallbridges(self, vhost=None):
        "Get all ``(dpid, name, _uuid)`` tuple for all connections, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is not None:
            self.apiroutine.retvalue = [
                (dpid, name, buuid)
                for c, bridges in self.managed_bridges.items()
                if c.protocol.vhost == vhost
                for _, dpid, name, buuid in bridges
            ]
        else:
            self.apiroutine.retvalue = [
                (dpid, name, buuid)
                for c, bridges in self.managed_bridges.items()
                for _, dpid, name, buuid in bridges
            ]

    def getbridge(self, connection, name):
        "Get datapath ID on this connection with specified name"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, n, _ in bridges:
                if n == name:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None

    def waitbridge(self, connection, name, timeout=30):
        "Wait for bridge with specified name appears and return the datapath-id"
        bnames = self.bridgenames
        if bnames is not None and name not in bnames:
            raise OVSDBBridgeNotAppearException(
                'Bridge ' + repr(name) + ' does not appear: it is not in the selected bridge names')
        for m in self.getbridge(connection, name):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(
                OVSDBBridgeSetup.UP, None, None, name, connection)
            conn_down = JsonRPCConnectionStateEvent.createMatcher(
                JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                connection, connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException(
                    'Bridge ' + repr(name) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException(
                    'Connection is down before bridge ' + repr(name) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getbridgebyuuid(self, connection, uuid):
        "Get datapath ID of bridge on this connection with specified _uuid"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, _, buuid in bridges:
                if buuid == uuid:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None

    def waitbridgebyuuid(self, connection, uuid, timeout=30):
        "Wait for bridge with specified _uuid appears and return the datapath-id"
        for m in self.getbridgebyuuid(connection, uuid):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(
                state=OVSDBBridgeSetup.UP,
                connection=connection,
                bridgeuuid=uuid)
            conn_down = JsonRPCConnectionStateEvent.createMatcher(
                JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                connection, connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException(
                    'Bridge ' + repr(uuid) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException(
                    'Connection is down before bridge ' + repr(uuid) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getsystemids(self, vhost=''):
        "Get All system-ids"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_systemids.keys()
                                    if k[0] == vhost]

    def getallsystemids(self):
        "Get all system-ids from any vhost. Return ``(vhost, system-id)`` pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_systemids.keys())

    def getconnectionbysystemid(self, systemid, vhost=''):
        "Get the connection registered for this system-id, or None"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_systemids.get((vhost, systemid))

    def waitconnectionbysystemid(self, systemid, timeout=30, vhost=''):
        "Wait for a connection with specified system-id"
        for m in self.getconnectionbysystemid(systemid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(
                    timeout,
                    OVSDBConnectionSetup.createMatcher(systemid, None, None, vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getconnectionsbyendpoint(self, endpoint, vhost=''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost='', timeout=30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        # Resolve the name
        if not name:
            endpoint = ''
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM,
                       socket.IPPROTO_TCP,
                       socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(
                    timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                # Resolve is only allowed through asynchronous resolver
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                # An empty response must yield None, not a stale value
                self.apiroutine.retvalue = None
                for r in resp:
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # Ignore port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    # First resolved address with connections wins
                    if self.apiroutine.retvalue is not None:
                        break

    def getendpoints(self, vhost=''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns
                                    if k[0] == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return ``(vhost, endpoint)`` pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def getbridgeinfo(self, datapathid, vhost=''):
        "Get ``(bridgename, systemid, bridge_uuid)`` tuple from bridge datapathid"
        for m in self.getconnection(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is not None:
            c = self.apiroutine.retvalue
            bridges = self.managed_bridges.get(c)
            if bridges is not None:
                for _, dpid, n, buuid in bridges:
                    if dpid == datapathid:
                        self.apiroutine.retvalue = (n, c.ovsdb_systemid, buuid)
                        return
                self.apiroutine.retvalue = None
            else:
                self.apiroutine.retvalue = None

    def waitbridgeinfo(self, datapathid, timeout=30, vhost=''):
        "Wait for bridge with datapathid, and return ``(bridgename, systemid, bridge_uuid)`` tuple"
        for m in self.getbridgeinfo(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is None:
            for m in self.apiroutine.waitWithTimeout(
                    timeout,
                    OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.UP,
                                                   datapathid,
                                                   None, None, None, None,
                                                   vhost)):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException(
                    'Bridge 0x%016x does not appear before timeout' % (datapathid,))
            e = self.apiroutine.event
            self.apiroutine.retvalue = (e.name, e.systemid, e.bridgeuuid)
class OpenflowPortManager(Module):
    '''
    Manage Ports from Openflow Protocol
    '''
    service = True

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # (vhost, datapath-id) -> {port_no: port description}
        self.managed_ports = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.resync, self.apiroutine)
                       )
        self._synchronized = False

    def _get_ports(self, connection, protocol, onup=False, update=True):
        '''
        Acquire all ports of a connection and store them in ``managed_ports``.

        :param onup: for OpenFlow 1.0, reuse the features reply stored on the
                     connection instead of sending a new features request
        :param update: send ``OpenflowPortSynchronized`` and an ``update``
                       notification once the ports are stored

        Connection resets and OpenFlow protocol errors are swallowed.
        '''
        ofdef = connection.openflowdef
        dpid = connection.openflow_datapathid
        vhost = connection.protocol.vhost
        add = []
        try:
            if hasattr(ofdef, 'ofp_multipart_request'):
                # Openflow 1.3, use ofp_multipart_request to get ports
                for m in protocol.querymultipart(
                        ofdef.ofp_multipart_request(type=ofdef.OFPMP_PORT_DESC),
                        connection, self.apiroutine):
                    yield m
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                for msg in self.apiroutine.openflow_reply:
                    for p in msg.ports:
                        add.append(p)
                        ports[p.port_no] = p
            else:
                # Openflow 1.0, use features_request
                if onup:
                    # Use the features_reply on connection setup
                    reply = connection.openflow_featuresreply
                else:
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    # NOTE(review): unlike querymultipart above, this call
                    # passes neither the connection nor the container and
                    # reads the reply from retvalue - confirm against the
                    # protocol's querywithreply signature
                    for m in protocol.querywithreply(request):
                        yield m
                    reply = self.apiroutine.retvalue
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                for p in reply.ports:
                    add.append(p)
                    ports[p.port_no] = p
            if update:
                for m in self.apiroutine.waitForSend(
                        OpenflowPortSynchronized(connection)):
                    yield m
                for m in self.apiroutine.waitForSend(
                        ModuleNotification(self.getServiceName(), 'update',
                                           datapathid=connection.openflow_datapathid,
                                           connection=connection,
                                           vhost=protocol.vhost,
                                           add=add, remove=[],
                                           reason='connected')):
                    yield m
        except ConnectionResetException:
            pass
        except OpenflowProtocolException:
            pass

    def _get_existing_ports(self):
        '''
        Acquire ports of all existing main connections (auxiliary id 0),
        then mark the module synchronized.
        '''
        for m in callAPI(self.apiroutine, 'openflowmanager',
                         'getallconnections', {'vhost': None}):
            yield m
        with closing(self.apiroutine.executeAll(
                [self._get_ports(c, c.protocol, False, False)
                 for c in self.apiroutine.retvalue
                 if c.openflow_auxiliaryid == 0])) as g:
            for m in g:
                yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Block until the initial synchronization has completed"
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'),)

    def _manage_ports(self):
        '''
        Main routine: follow PORT_STATUS messages and connection updates from
        ``openflowmanager`` and keep ``managed_ports`` up to date.
        '''
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            conn_update = ModuleNotification.createMatcher('openflowmanager', 'update')
            port_status = OpenflowAsyncMessageEvent.createMatcher(
                of13.OFPT_PORT_STATUS, None, 0)
            while True:
                yield (conn_update, port_status)
                if self.apiroutine.matcher is port_status:
                    e = self.apiroutine.event
                    m = e.message
                    c = e.connection
                    key = (c.protocol.vhost, c.openflow_datapathid)
                    if key in self.managed_ports:
                        if m.reason == c.openflowdef.OFPPR_ADD:
                            # A new port is added
                            self.managed_ports[key][m.desc.port_no] = m.desc
                            self.scheduler.emergesend(
                                ModuleNotification(self.getServiceName(), 'update',
                                                   datapathid=c.openflow_datapathid,
                                                   connection=c,
                                                   vhost=c.protocol.vhost,
                                                   add=[m.desc], remove=[],
                                                   reason='add'))
                        elif m.reason == c.openflowdef.OFPPR_DELETE:
                            try:
                                del self.managed_ports[key][m.desc.port_no]
                                self.scheduler.emergesend(
                                    ModuleNotification(self.getServiceName(), 'update',
                                                       datapathid=c.openflow_datapathid,
                                                       connection=c,
                                                       vhost=c.protocol.vhost,
                                                       add=[], remove=[m.desc],
                                                       reason='delete'))
                            except KeyError:
                                pass
                        elif m.reason == c.openflowdef.OFPPR_MODIFY:
                            try:
                                self.scheduler.emergesend(
                                    ModuleNotification(self.getServiceName(), 'modified',
                                                       datapathid=c.openflow_datapathid,
                                                       connection=c,
                                                       vhost=c.protocol.vhost,
                                                       old=self.managed_ports[key][m.desc.port_no],
                                                       new=m.desc,
                                                       reason='modified'))
                            except KeyError:
                                # Unknown port modified: report it as added
                                self.scheduler.emergesend(
                                    ModuleNotification(self.getServiceName(), 'update',
                                                       datapathid=c.openflow_datapathid,
                                                       connection=c,
                                                       vhost=c.protocol.vhost,
                                                       add=[m.desc], remove=[],
                                                       reason='add'))
                            self.managed_ports[key][m.desc.port_no] = m.desc
                else:
                    e = self.apiroutine.event
                    for c in e.remove:
                        key = (c.protocol.vhost, c.openflow_datapathid)
                        if c.openflow_auxiliaryid == 0 and key in self.managed_ports:
                            self.scheduler.emergesend(
                                ModuleNotification(self.getServiceName(), 'update',
                                                   datapathid=c.openflow_datapathid,
                                                   connection=c,
                                                   vhost=c.protocol.vhost,
                                                   add=[],
                                                   remove=list(self.managed_ports[key].values()),
                                                   reason='disconnected'))
                            del self.managed_ports[key]
                    for c in e.add:
                        if c.openflow_auxiliaryid == 0:
                            self.apiroutine.subroutine(
                                self._get_ports(c, c.protocol, True, True))
        finally:
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getports(self, datapathid, vhost=''):
        "Return all ports of a specified datapath"
        for m in self._wait_for_sync():
            yield m
        r = self.managed_ports.get((vhost, datapathid))
        if r is None:
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = list(r.values())

    def getallports(self, vhost=None):
        "Return all (datapathid, port, vhost) tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = [(dpid, p, vh)
                                        for (vh, dpid), v in self.managed_ports.items()
                                        for p in v.values()]
        else:
            self.apiroutine.retvalue = [(dpid, p, vh)
                                        for (vh, dpid), v in self.managed_ports.items()
                                        if vh == vhost
                                        for p in v.values()]

    def getportbyno(self, datapathid, portno, vhost=''):
        "Return port with specified portno"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost)

    def _getportbyno(self, datapathid, portno, vhost=''):
        "Synchronous lookup by port number; None if datapath or port is unknown"
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        else:
            return ports.get(portno)

    def waitportbyno(self, datapathid, portno, timeout=30, vhost=''):
        "Wait for the port with specified portno to appear and return it"
        for m in self._wait_for_sync():
            yield m

        def waitinner():
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                for m in callAPI(self.apiroutine, 'openflowmanager', 'waitconnection',
                                 {'datapathid': datapathid, 'vhost': vhost, 'timeout': timeout}):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected, but the ports are not acquired yet
                    yield (OpenflowPortSynchronized.createMatcher(c),)
                    ports = self.managed_ports.get((vhost, datapathid))
                    if ports is None:
                        raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            if portno not in ports:
                yield (OpenflowAsyncMessageEvent.createMatcher(
                    of13.OFPT_PORT_STATUS, datapathid, 0,
                    _ismatch=lambda x: x.message.desc.port_no == portno),)
                self.apiroutine.retvalue = self.apiroutine.event.message.desc
            else:
                self.apiroutine.retvalue = ports[portno]
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException(
                'Port %d does not appear on datapath %016x' % (portno, datapathid))

    def getportbyname(self, datapathid, name, vhost=''):
        "Return port with specified name"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)

    def _getportbyname(self, datapathid, name, vhost=''):
        "Synchronous lookup by port name; None if datapath or port is unknown"
        # Names are compared as bytes
        if not isinstance(name, bytes):
            name = _bytes(name)
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        else:
            for p in ports.values():
                if p.name == name:
                    return p
            return None

    def waitportbyname(self, datapathid, name, timeout=30, vhost=''):
        "Wait for the port with specified name to appear and return it"
        for m in self._wait_for_sync():
            yield m
        if not isinstance(name, bytes):
            name = _bytes(name)

        def waitinner():
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                for m in callAPI(self.apiroutine, 'openflowmanager', 'waitconnection',
                                 {'datapathid': datapathid, 'vhost': vhost, 'timeout': timeout}):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected, but the ports are not acquired yet
                    yield (OpenflowPortSynchronized.createMatcher(c),)
                    ports = self.managed_ports.get((vhost, datapathid))
                    if ports is None:
                        raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            for p in ports.values():
                if p.name == name:
                    self.apiroutine.retvalue = p
                    return
            yield (OpenflowAsyncMessageEvent.createMatcher(
                of13.OFPT_PORT_STATUS, datapathid, 0,
                _ismatch=lambda x: x.message.desc.name == name),)
            self.apiroutine.retvalue = self.apiroutine.event.message.desc
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException(
                'Port %r does not appear on datapath %016x' % (name, datapathid))

    def resync(self, datapathid, vhost=''):
        '''
        Resync with current ports
        '''
        # Sometimes when the OpenFlow connection is very busy, PORT_STATUS message may be dropped.
        # We must deal with this and recover from it
        key = (vhost, datapathid)
        # Save current managed_ports
        if key not in self.managed_ports:
            self.apiroutine.retvalue = None
            return
        else:
            last_ports = set(self.managed_ports[key].keys())
        add = set()
        remove = set()
        # Every port description seen in any round, for the final lookup
        ports = {}
        for _ in range(0, 10):
            for m in callAPI(self.apiroutine, 'openflowmanager', 'getconnection',
                             {'datapathid': datapathid, 'vhost': vhost}):
                yield m
            c = self.apiroutine.retvalue
            if c is None:
                # Disconnected, will automatically resync when reconnected
                self.apiroutine.retvalue = None
                return
            ofdef = c.openflowdef
            protocol = c.protocol
            # Ports reported by the switch in THIS round only
            acquired = {}
            try:
                if hasattr(ofdef, 'ofp_multipart_request'):
                    # Openflow 1.3, use ofp_multipart_request to get ports
                    for m in protocol.querymultipart(
                            ofdef.ofp_multipart_request(type=ofdef.OFPMP_PORT_DESC),
                            c, self.apiroutine):
                        yield m
                    for msg in self.apiroutine.openflow_reply:
                        for p in msg.ports:
                            acquired[p.port_no] = p
                else:
                    # Openflow 1.0, use features_request
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    # NOTE(review): called without connection/container,
                    # unlike querymultipart above - confirm the signature
                    for m in protocol.querywithreply(request):
                        yield m
                    reply = self.apiroutine.retvalue
                    for p in reply.ports:
                        acquired[p.port_no] = p
            except ConnectionResetException:
                break
            except OpenflowProtocolException:
                break
            else:
                if key not in self.managed_ports:
                    self.apiroutine.retvalue = None
                    return
                ports.update(acquired)
                current_ports = set(self.managed_ports[key])
                # If a port is already removed
                remove.intersection_update(current_ports)
                # If a port is already added
                add.difference_update(current_ports)
                # If a port is not acquired, we do not add it
                acquired_keys = set(acquired)
                add.intersection_update(acquired_keys)
                # Add and remove previous added/removed ports
                current_ports.difference_update(remove)
                current_ports.update(add)
                # If there are changed ports, the changed ports may or may not appear in the acquired port list
                # We only deal with following situations:
                # 1. If both lack ports, we add them
                # 2. If both have additional ports, we remove them
                to_add = acquired_keys.difference(current_ports.union(last_ports))
                to_remove = current_ports.intersection(last_ports).difference(acquired_keys)
                if not to_add and not to_remove and current_ports == last_ports:
                    # Two consecutive rounds agree: the result is stable
                    break
                else:
                    add.update(to_add)
                    remove.update(to_remove)
                    current_ports.update(to_add)
                    current_ports.difference_update(to_remove)
                    last_ports = current_ports
        # Actual add and remove
        mports = self.managed_ports.get(key)
        if mports is None:
            # Datapath disappeared while resyncing; nothing left to fix
            self.apiroutine.retvalue = None
            return
        add_ports = []
        remove_ports = []
        for k in add:
            if k not in mports:
                add_ports.append(ports[k])
                mports[k] = ports[k]
        for k in remove:
            try:
                oldport = mports.pop(k)
            except KeyError:
                pass
            else:
                remove_ports.append(oldport)
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'update',
                                   datapathid=datapathid,
                                   connection=c,
                                   vhost=vhost,
                                   add=add_ports, remove=remove_ports,
                                   reason='resync')):
            yield m
        self.apiroutine.retvalue = None
class OpenflowPortManager(Module):
    '''
    Manage Ports from Openflow Protocol

    Port descriptions are cached in ``self.managed_ports``, keyed by
    ``(vhost, datapathid)`` and then by ``port_no``. Changes are published as
    ``ModuleNotification`` events of this service ('update', 'modified',
    'synchronized', 'unsynchronized').
    '''
    service = True

    def __init__(self, server):
        '''
        Create the main routine and register the public API methods.
        '''
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> {port_no: port description}
        self.managed_ports = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.resync, self.apiroutine))
        # Becomes True when _get_existing_ports has finished
        self._synchronized = False

    def _get_ports(self, connection, protocol, onup=False, update=True):
        '''
        Routine: query all ports of *connection* and store them in
        ``self.managed_ports``.

        :param connection: OpenFlow connection to query
        :param protocol: protocol object used to send the query
        :param onup: for OpenFlow 1.0 only, reuse
                     ``connection.openflow_featuresreply`` instead of sending
                     a new features request
        :param update: when True, send ``OpenflowPortSynchronized`` and an
                       'update' notification (reason 'connected') afterwards

        ``ConnectionResetException`` and ``OpenflowProtocolException`` are
        swallowed: the routine just ends.
        '''
        ofdef = connection.openflowdef
        dpid = connection.openflow_datapathid
        vhost = connection.protocol.vhost
        add = []
        try:
            if hasattr(ofdef, 'ofp_multipart_request'):
                # Openflow 1.3, use ofp_multipart_request to get ports
                for m in protocol.querymultipart(
                        ofdef.ofp_multipart_request(type=ofdef.OFPMP_PORT_DESC),
                        connection, self.apiroutine):
                    yield m
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                for msg in self.apiroutine.openflow_reply:
                    for p in msg.ports:
                        add.append(p)
                        ports[p.port_no] = p
            else:
                # Openflow 1.0, use features_request
                if onup:
                    # Use the features_reply on connection setup
                    reply = connection.openflow_featuresreply
                else:
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    # The query needs the connection and the container, exactly
                    # like querymultipart above; they were missing before.
                    for m in protocol.querywithreply(request, connection,
                                                     self.apiroutine):
                        yield m
                    # NOTE(review): the multipart branch reads its result from
                    # apiroutine.openflow_reply; confirm that querywithreply
                    # really publishes the reply through retvalue.
                    reply = self.apiroutine.retvalue
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                for p in reply.ports:
                    add.append(p)
                    ports[p.port_no] = p
            if update:
                for m in self.apiroutine.waitForSend(
                        OpenflowPortSynchronized(connection)):
                    yield m
                self._logger.info(
                    "Openflow port synchronized on connection %r", connection)
                for m in self.apiroutine.waitForSend(
                        ModuleNotification(
                            self.getServiceName(), 'update',
                            datapathid=connection.openflow_datapathid,
                            connection=connection,
                            vhost=protocol.vhost,
                            add=add,
                            remove=[],
                            reason='connected')):
                    yield m
        except ConnectionResetException:
            pass
        except OpenflowProtocolException:
            pass

    def _get_existing_ports(self):
        '''
        Routine: load ports from every main connection (auxiliary id 0)
        reported by 'openflowmanager', then mark this module synchronized and
        send the 'synchronized' notification.
        '''
        for m in callAPI(self.apiroutine, 'openflowmanager',
                         'getallconnections', {'vhost': None}):
            yield m
        main_conns = [c for c in self.apiroutine.retvalue
                      if c.openflow_auxiliaryid == 0]
        with closing(self.apiroutine.executeAll(
                [self._get_ports(c, c.protocol, False, False)
                 for c in main_conns])) as g:
            for m in g:
                yield m
        self._synchronized = True
        self._logger.info("Openflow ports synchronized")
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        '''
        Routine: if the initial load has not finished yet, wait for this
        service's 'synchronized' notification.
        '''
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'), )

    def _manage_ports(self):
        '''
        Main routine: start the initial load, then keep ``managed_ports`` up
        to date from PORT_STATUS messages and 'openflowmanager' connection
        updates. Sends 'unsynchronized' when the routine ends.
        '''
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            conn_update = ModuleNotification.createMatcher(
                'openflowmanager', 'update')
            port_status = OpenflowAsyncMessageEvent.createMatcher(
                of13.OFPT_PORT_STATUS, None, 0)
            while True:
                yield (conn_update, port_status)
                if self.apiroutine.matcher is port_status:
                    e = self.apiroutine.event
                    m = e.message
                    c = e.connection
                    key = (c.protocol.vhost, c.openflow_datapathid)
                    # Port status of datapaths we do not track is ignored
                    if key in self.managed_ports:
                        dp_ports = self.managed_ports[key]
                        if m.reason == c.openflowdef.OFPPR_ADD:
                            # A new port is added
                            dp_ports[m.desc.port_no] = m.desc
                            self.scheduler.emergesend(
                                ModuleNotification(
                                    self.getServiceName(), 'update',
                                    datapathid=c.openflow_datapathid,
                                    connection=c,
                                    vhost=c.protocol.vhost,
                                    add=[m.desc],
                                    remove=[],
                                    reason='add'))
                        elif m.reason == c.openflowdef.OFPPR_DELETE:
                            try:
                                del dp_ports[m.desc.port_no]
                            except KeyError:
                                # Unknown port: nothing to remove or report
                                pass
                            else:
                                self.scheduler.emergesend(
                                    ModuleNotification(
                                        self.getServiceName(), 'update',
                                        datapathid=c.openflow_datapathid,
                                        connection=c,
                                        vhost=c.protocol.vhost,
                                        add=[],
                                        remove=[m.desc],
                                        reason='delete'))
                        elif m.reason == c.openflowdef.OFPPR_MODIFY:
                            try:
                                old_desc = dp_ports[m.desc.port_no]
                            except KeyError:
                                # Modification of a port we never saw is
                                # reported as an addition
                                self.scheduler.emergesend(
                                    ModuleNotification(
                                        self.getServiceName(), 'update',
                                        datapathid=c.openflow_datapathid,
                                        connection=c,
                                        vhost=c.protocol.vhost,
                                        add=[m.desc],
                                        remove=[],
                                        reason='add'))
                            else:
                                self.scheduler.emergesend(
                                    ModuleNotification(
                                        self.getServiceName(), 'modified',
                                        datapathid=c.openflow_datapathid,
                                        connection=c,
                                        vhost=c.protocol.vhost,
                                        old=old_desc,
                                        new=m.desc,
                                        reason='modified'))
                            dp_ports[m.desc.port_no] = m.desc
                else:
                    e = self.apiroutine.event
                    for c in e.remove:
                        key = (c.protocol.vhost, c.openflow_datapathid)
                        if c.openflow_auxiliaryid == 0 and key in self.managed_ports:
                            self.scheduler.emergesend(
                                ModuleNotification(
                                    self.getServiceName(), 'update',
                                    datapathid=c.openflow_datapathid,
                                    connection=c,
                                    vhost=c.protocol.vhost,
                                    add=[],
                                    remove=list(self.managed_ports[key].values()),
                                    reason='disconnected'))
                            del self.managed_ports[key]
                    for c in e.add:
                        if c.openflow_auxiliaryid == 0:
                            self.apiroutine.subroutine(
                                self._get_ports(c, c.protocol, True, True))
        finally:
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getports(self, datapathid, vhost=''):
        "Return all ports of a specified datapath, or None if the datapath is not managed"
        for m in self._wait_for_sync():
            yield m
        r = self.managed_ports.get((vhost, datapathid))
        if r is None:
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = list(r.values())

    def getallports(self, vhost=None):
        "Return all ``(datapathid, port, vhost)`` tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = [
                (dpid, p, vh)
                for (vh, dpid), v in self.managed_ports.items()
                for p in v.values()
            ]
        else:
            self.apiroutine.retvalue = [
                (dpid, p, vh)
                for (vh, dpid), v in self.managed_ports.items()
                if vh == vhost
                for p in v.values()
            ]

    def getportbyno(self, datapathid, portno, vhost=''):
        "Return port with specified OpenFlow portno, or None if it is not known"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost)

    def _getportbyno(self, datapathid, portno, vhost=''):
        '''
        Synchronous lookup of a port by number in ``managed_ports``; returns
        None when the datapath or the port is unknown.
        '''
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        return ports.get(portno)

    def waitportbyno(self, datapathid, portno, timeout=30, vhost=''):
        """
        Wait for the specified OpenFlow portno to appear, or until timeout.

        :raises ConnectionResetException: the datapath still has no port table
            after its connection was reported and synchronized
        :raises OpenflowPortNotAppearException: the port did not appear within
            *timeout*
        """
        for m in self._wait_for_sync():
            yield m

        def waitinner():
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                # Datapath not known yet: wait for its connection first
                for m in callAPI(self.apiroutine, 'openflowmanager',
                                 'waitconnection', {
                                     'datapathid': datapathid,
                                     'vhost': vhost,
                                     'timeout': timeout
                                 }):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected but ports not loaded yet
                    yield (OpenflowPortSynchronized.createMatcher(c), )
                    ports = self.managed_ports.get((vhost, datapathid))
                    if ports is None:
                        raise ConnectionResetException(
                            'Datapath %016x is not connected' % datapathid)
            if portno not in ports:
                yield (OpenflowAsyncMessageEvent.createMatcher(
                    of13.OFPT_PORT_STATUS, datapathid, 0,
                    _ismatch=lambda x: x.message.desc.port_no == portno), )
                self.apiroutine.retvalue = self.apiroutine.event.message.desc
            else:
                self.apiroutine.retvalue = ports[portno]

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException(
                'Port %d does not appear on datapath %016x' %
                (portno, datapathid))

    def getportbyname(self, datapathid, name, vhost=''):
        "Return port with specified port name, or None if it is not known"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)

    def _getportbyname(self, datapathid, name, vhost=''):
        '''
        Synchronous lookup of a port by name; *name* is converted with
        ``_bytes`` when it is not already a bytes object. Returns None when
        the datapath or the port is unknown.
        '''
        if not isinstance(name, bytes):
            name = _bytes(name)
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        for p in ports.values():
            if p.name == name:
                return p
        return None

    def waitportbyname(self, datapathid, name, timeout=30, vhost=''):
        """
        Wait for a port with the specified port name to appear, or until timeout

        :raises ConnectionResetException: the datapath still has no port table
            after its connection was reported and synchronized
        :raises OpenflowPortNotAppearException: the port did not appear within
            *timeout*
        """
        for m in self._wait_for_sync():
            yield m
        if not isinstance(name, bytes):
            name = _bytes(name)

        def waitinner():
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                # Datapath not known yet: wait for its connection first
                for m in callAPI(self.apiroutine, 'openflowmanager',
                                 'waitconnection', {
                                     'datapathid': datapathid,
                                     'vhost': vhost,
                                     'timeout': timeout
                                 }):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected but ports not loaded yet
                    yield (OpenflowPortSynchronized.createMatcher(c), )
                    ports = self.managed_ports.get((vhost, datapathid))
                    if ports is None:
                        raise ConnectionResetException(
                            'Datapath %016x is not connected' % datapathid)
            for p in ports.values():
                if p.name == name:
                    self.apiroutine.retvalue = p
                    return
            yield (OpenflowAsyncMessageEvent.createMatcher(
                of13.OFPT_PORT_STATUS, datapathid, 0,
                _ismatch=lambda x: x.message.desc.name == name), )
            self.apiroutine.retvalue = self.apiroutine.event.message.desc

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException(
                'Port %r does not appear on datapath %016x' %
                (name, datapathid))

    def resync(self, datapathid, vhost=''):
        '''
        Resync with current ports

        Queries the switch (up to 10 times) until the queried port list and
        the managed port list agree, then applies the collected additions and
        removals and sends one 'update' notification with reason 'resync'.
        The API return value is always None.
        '''
        # Sometimes when the OpenFlow connection is very busy, PORT_STATUS message may be dropped.
        # We must deal with this and recover from it
        # Save current managed_ports
        if (vhost, datapathid) not in self.managed_ports:
            self.apiroutine.retvalue = None
            return
        else:
            last_ports = set(self.managed_ports[(vhost, datapathid)].keys())
        add = set()
        remove = set()
        # Accumulates every port description acquired during the retries
        ports = {}
        for _ in range(0, 10):
            for m in callAPI(self.apiroutine, 'openflowmanager',
                             'getconnection', {
                                 'datapathid': datapathid,
                                 'vhost': vhost
                             }):
                yield m
            c = self.apiroutine.retvalue
            if c is None:
                # Disconnected, will automatically resync when reconnected
                self.apiroutine.retvalue = None
                return
            ofdef = c.openflowdef
            protocol = c.protocol
            try:
                if hasattr(ofdef, 'ofp_multipart_request'):
                    # Openflow 1.3, use ofp_multipart_request to get ports
                    for m in protocol.querymultipart(
                            ofdef.ofp_multipart_request(
                                type=ofdef.OFPMP_PORT_DESC),
                            c, self.apiroutine):
                        yield m
                    for msg in self.apiroutine.openflow_reply:
                        for p in msg.ports:
                            ports[p.port_no] = p
                else:
                    # Openflow 1.0, use features_request
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    # The query needs the connection and the container, exactly
                    # like querymultipart above; they were missing before.
                    for m in protocol.querywithreply(request, c,
                                                     self.apiroutine):
                        yield m
                    # NOTE(review): confirm that querywithreply publishes the
                    # reply through retvalue rather than openflow_reply.
                    reply = self.apiroutine.retvalue
                    for p in reply.ports:
                        ports[p.port_no] = p
            except ConnectionResetException:
                break
            except OpenflowProtocolException:
                break
            else:
                if (vhost, datapathid) not in self.managed_ports:
                    self.apiroutine.retvalue = None
                    return
                current_ports = set(self.managed_ports[(vhost, datapathid)])
                # If a port is already removed
                remove.intersection_update(current_ports)
                # If a port is already added
                add.difference_update(current_ports)
                # If a port is not acquired, we do not add it.
                # (This used to be difference_update, which discarded exactly
                # the acquired ports and made the loop oscillate, so missed
                # ports were never added.)
                acquired_keys = set(ports.keys())
                add.intersection_update(acquired_keys)
                # Add and remove previous added/removed ports
                current_ports.difference_update(remove)
                current_ports.update(add)
                # If there are changed ports, the changed ports may or may not appear in the acquired port list
                # We only deal with following situations:
                # 1. If both lack ports, we add them
                # 2. If both have additional ports, we remove them
                to_add = acquired_keys.difference(
                    current_ports.union(last_ports))
                to_remove = current_ports.intersection(last_ports).difference(
                    acquired_keys)
                if not to_add and not to_remove and current_ports == last_ports:
                    # Two consecutive rounds agree: the result is stable
                    break
                else:
                    add.update(to_add)
                    remove.update(to_remove)
                    current_ports.update(to_add)
                    current_ports.difference_update(to_remove)
                    last_ports = current_ports
        # Actual add and remove
        mports = self.managed_ports.get((vhost, datapathid))
        if mports is None:
            # The datapath disappeared while we were querying (the loop also
            # ends on a connection reset); nothing left to resync
            self.apiroutine.retvalue = None
            return
        add_ports = []
        remove_ports = []
        for k in add:
            if k not in mports:
                add_ports.append(ports[k])
                mports[k] = ports[k]
        for k in remove:
            try:
                oldport = mports.pop(k)
            except KeyError:
                pass
            else:
                remove_ports.append(oldport)
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'update',
                                   datapathid=datapathid,
                                   connection=c,
                                   vhost=vhost,
                                   add=add_ports,
                                   remove=remove_ports,
                                   reason='resync')):
            yield m
        self.apiroutine.retvalue = None
class OVSDBPortManager(Module):
    '''
    Manage Ports from OVSDB Protocol

    Interface records are cached in ``self.managed_ports`` (keyed by
    ``(vhost, datapath_id)``, values are lists of ``(port_uuid, interface)``)
    and, for interfaces that carry an ``iface-id`` external id, in
    ``self.managed_ids`` (keyed by ``(vhost, id)``, values are lists of
    ``(datapath_id, interface)``).
    '''
    service = True

    def __init__(self, server):
        '''
        Create the main routine and register the public API methods.
        '''
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # (vhost, datapath_id) -> [(port_uuid, interface dict), ...]
        self.managed_ports = {}
        # (vhost, iface-id) -> [(datapath_id, interface dict), ...]
        self.managed_ids = {}
        # Running per-connection monitor routines (see _get_ports)
        self.monitor_routines = set()
        # port uuid -> uuid of the bridge that contains it
        self.ports_uuids = {}
        # Waiter counters; an entry makes _get_interface_info send an
        # OVSDBPortUpNotification when a matching interface shows up
        self.wait_portnos = {}
        self.wait_names = {}
        self.wait_ids = {}
        # bridge uuid -> datapath id
        self.bridge_datapathid = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyid, self.apiroutine),
                       api(self.waitportbyid, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.resync, self.apiroutine))
        # Becomes True when _get_existing_ports has finished
        self._synchronized = False

    def _get_interface_info(self, connection, protocol, buuid, interface_uuid,
                            port_uuid):
        '''
        Routine: fetch one Interface row and register it.

        The routine's ``retvalue`` is ``[interface]`` on success and ``[]``
        when the interface is skipped (third wait failed, no row, negative
        ofport, unknown bridge, or a ``JsonRPCProtocolException``).

        :param buuid: uuid of the bridge; must be present in
                      ``self.bridge_datapathid``
        :param interface_uuid: uuid of the Interface row to select
        :param port_uuid: uuid of the owning port, stored next to the
                          interface in ``managed_ports``
        '''
        try:
            iface_cond = [["_uuid", "==", ovsdb.uuid(interface_uuid)]]
            # Three waits followed by the select, in one transaction.
            # NOTE(review): presumably the waits block until ofport/ifindex
            # are populated and ofport is not -1 -- confirm against
            # ovsdb.wait's argument semantics.
            method, params = ovsdb.transact(
                'Open_vSwitch',
                ovsdb.wait('Interface', iface_cond, ["ofport"],
                           [{"ofport": ovsdb.oset()}], False, 5000),
                ovsdb.wait('Interface', iface_cond, ["ofport"],
                           [{"ofport": -1}], False, 0),
                ovsdb.wait('Interface', iface_cond, ["ifindex"],
                           [{"ifindex": ovsdb.oset()}], False, 5000),
                ovsdb.select('Interface', iface_cond, [
                    "_uuid", "name", "ifindex", "ofport", "type",
                    "external_ids"
                ]))
            for m in protocol.querywithreply(method, params, connection,
                                             self.apiroutine):
                yield m
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[2]
            if 'error' in r:
                # Ignore this port because it is in an error state
                self.apiroutine.retvalue = []
                return
            r = self.apiroutine.jsonrpc_result[3]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring interface: ' + repr(r['error']))
            if not r['rows']:
                self.apiroutine.retvalue = []
                return
            r0 = r['rows'][0]
            if r0['ofport'] < 0:
                # Ignore this port because it is in an error state
                self.apiroutine.retvalue = []
                return
            # Normalize the OVSDB wire values in place
            r0['_uuid'] = r0['_uuid'][1]
            r0['ifindex'] = ovsdb.getoptional(r0['ifindex'])
            r0['external_ids'] = ovsdb.getdict(r0['external_ids'])
            if buuid not in self.bridge_datapathid:
                self.apiroutine.retvalue = []
                return
            datapath_id = self.bridge_datapathid[buuid]
            if 'iface-id' in r0['external_ids']:
                eid = r0['external_ids']['iface-id']
                r0['id'] = eid
                id_ports = self.managed_ids.setdefault((protocol.vhost, eid),
                                                       [])
                id_ports.append((datapath_id, r0))
            else:
                r0['id'] = None
            self.managed_ports.setdefault((protocol.vhost, datapath_id),
                                          []).append((port_uuid, r0))
            # Notify only if somebody registered interest in this port
            notify = False
            no_key = (protocol.vhost, datapath_id, r0['ofport'])
            if no_key in self.wait_portnos:
                notify = True
                del self.wait_portnos[no_key]
            name_key = (protocol.vhost, datapath_id, r0['name'])
            if name_key in self.wait_names:
                notify = True
                del self.wait_names[name_key]
            id_key = (protocol.vhost, r0['id'])
            if id_key in self.wait_ids:
                notify = True
                del self.wait_ids[id_key]
            if notify:
                for m in self.apiroutine.waitForSend(
                        OVSDBPortUpNotification(connection,
                                                r0['name'],
                                                r0['ofport'],
                                                r0['id'],
                                                protocol.vhost,
                                                datapath_id,
                                                port=r0)):
                    yield m
            self.apiroutine.retvalue = [r0]
        except JsonRPCProtocolException:
            self.apiroutine.retvalue = []

    def _remove_interface_id(self, connection, protocol, datapath_id, port):
        '''
        Remove *port* (matched by ``_uuid``) from the ``managed_ids`` list of
        its ``id``. Does nothing when no list is registered for that id.
        '''
        eid = port['id']
        eid_list = self.managed_ids.get((protocol.vhost, eid))
        if not eid_list:
            # Nothing registered under this id (used to raise TypeError)
            return
        for i, (_, registered) in enumerate(eid_list):
            if registered['_uuid'] == port['_uuid']:
                del eid_list[i]
                break

    def _remove_interface(self, connection, protocol, datapath_id,
                          interface_uuid, port_uuid):
        '''
        Remove the interface with uuid *interface_uuid* from the datapath.

        :return: the removed interface dict, or None if it was not managed
        '''
        ports = self.managed_ports.get((protocol.vhost, datapath_id))
        r = None
        if ports is not None:
            for i, (_, iface) in enumerate(ports):
                if iface['_uuid'] == interface_uuid:
                    r = iface
                    if r['id']:
                        self._remove_interface_id(connection, protocol,
                                                  datapath_id, r)
                    del ports[i]
                    break
            if not ports:
                # Last interface gone: drop the datapath entry as well
                del self.managed_ports[(protocol.vhost, datapath_id)]
        return r

    def _remove_all_interface(self, connection, protocol, datapath_id,
                              port_uuid, buuid):
        '''
        Remove every interface that belongs to port *port_uuid* on the
        datapath.

        :return: list of removed interface dicts
        '''
        ports = self.managed_ports.get((protocol.vhost, datapath_id))
        if ports is not None:
            removed_ports = [r for puuid, r in ports if puuid == port_uuid]
            # Update the list in place; other references must see the change
            ports[:] = [(puuid, r) for puuid, r in ports
                        if puuid != port_uuid]
            for r in removed_ports:
                if r['id']:
                    self._remove_interface_id(connection, protocol,
                                              datapath_id, r)
            if not ports:
                del self.managed_ports[(protocol.vhost, datapath_id)]
            # NOTE(review): returning here skips the ports_uuids cleanup
            # below whenever the datapath has managed ports -- confirm that
            # this is intended.
            return removed_ports
        if port_uuid in self.ports_uuids and self.ports_uuids[
                port_uuid] == buuid:
            del self.ports_uuids[port_uuid]
        return []

    def _update_interfaces(self, connection, protocol, updateinfo,
                           update=True):
        """
        There are several kinds of updates, they may appear together:

        1. New bridge created (or from initial updateinfo). We should add all the interfaces to the list.

        2. Bridge removed. Remove all the ports.

        3. name and datapath_id may be changed. We will consider this as a new bridge created, and an old
           bridge removed.

        4. Bridge ports modification, i.e. add/remove ports.
           a) Normally a port record is created/deleted together. A port record cannot exist without a
              bridge containing it.
           b) It is also possible that a port is removed from one bridge and added to another bridge, in
              this case the ports do not appear in the updateinfo

        5. Port interfaces modification, i.e. add/remove interfaces. The bridge record may not appear in this
           situation.

        We must consider these situations carefully and process them in correct order.

        :param updateinfo: dict with optional 'Bridge' and 'Port' tables,
                           each mapping uuid -> {'old': ..., 'new': ...}
        :param update: True for incremental updates (work is started as
                       subroutines and 'update' notifications are sent);
                       False for the initial state (work is awaited and
                       ``OVSDBConnectionPortsSynchronized`` is sent at the end)
        """
        port_update = updateinfo.get('Port', {})
        bridge_update = updateinfo.get('Bridge', {})
        working_routines = []

        def process_bridge(buuid, uo):
            try:
                nv = uo['new']
                if 'datapath_id' in nv:
                    if ovsdb.getoptional(nv['datapath_id']) is None:
                        # This bridge is not initialized. Wait for the bridge to be initialized.
                        for m in callAPI(
                                self.apiroutine, 'ovsdbmanager', 'waitbridge',
                                {
                                    'connection': connection,
                                    'name': nv['name'],
                                    'timeout': 5
                                }):
                            yield m
                        datapath_id = self.apiroutine.retvalue
                    else:
                        datapath_id = int(nv['datapath_id'], 16)
                    self.bridge_datapathid[buuid] = datapath_id
                elif buuid in self.bridge_datapathid:
                    datapath_id = self.bridge_datapathid[buuid]
                else:
                    # This should not happen, but just in case...
                    for m in callAPI(self.apiroutine, 'ovsdbmanager',
                                     'waitbridge', {
                                         'connection': connection,
                                         'name': nv['name'],
                                         'timeout': 5
                                     }):
                        yield m
                    datapath_id = self.apiroutine.retvalue
                    self.bridge_datapathid[buuid] = datapath_id
                if 'ports' in nv:
                    nset = set((p for _, p in ovsdb.getlist(nv['ports'])))
                else:
                    nset = set()
                if 'old' in uo:
                    ov = uo['old']
                    if 'ports' in ov:
                        oset = set((p for _, p in ovsdb.getlist(ov['ports'])))
                    else:
                        # new ports are not really added; it is only sent because datapath_id is modified
                        nset = set()
                        oset = set()
                    if 'datapath_id' in ov and ovsdb.getoptional(
                            ov['datapath_id']) is not None:
                        old_datapathid = int(ov['datapath_id'], 16)
                    else:
                        old_datapathid = datapath_id
                else:
                    oset = set()
                    old_datapathid = datapath_id
                # For every deleted port, remove the interfaces with this port _uuid
                remove = []
                add_routine = []
                for puuid in oset - nset:
                    remove += self._remove_all_interface(
                        connection, protocol, old_datapathid, puuid, buuid)
                # For every port not changed, check if the interfaces are modified;
                for puuid in oset.intersection(nset):
                    if puuid in port_update:
                        # The port is modified, there should be an 'old' set and 'new' set
                        pu = port_update[puuid]
                        if 'old' in pu:
                            poset = set((p for _, p in ovsdb.getlist(
                                pu['old']['interfaces'])))
                        else:
                            poset = set()
                        if 'new' in pu:
                            pnset = set((p for _, p in ovsdb.getlist(
                                pu['new']['interfaces'])))
                        else:
                            pnset = set()
                        # Remove old interfaces
                        remove += [
                            r for r in (self._remove_interface(
                                connection, protocol, datapath_id, iuuid,
                                puuid) for iuuid in (poset - pnset))
                            if r is not None
                        ]
                        # Prepare to add new interfaces
                        add_routine += [
                            self._get_interface_info(connection, protocol,
                                                     buuid, iuuid, puuid)
                            for iuuid in (pnset - poset)
                        ]

                # For every port added, add the interfaces
                def add_port_interfaces(puuid):
                    # If the uuid does not appear in update info, we have no choice but to query interfaces with select
                    # we cannot use data from other bridges; the port may be moved from a bridge which is not tracked
                    try:
                        method, params = ovsdb.transact(
                            'Open_vSwitch',
                            ovsdb.select('Port',
                                         [["_uuid", "==",
                                           ovsdb.uuid(puuid)]],
                                         ["interfaces"]))
                        for m in protocol.querywithreply(
                                method, params, connection, self.apiroutine):
                            yield m
                        r = self.apiroutine.jsonrpc_result[0]
                        if 'error' in r:
                            raise JsonRPCErrorResultException(
                                'Error when query interfaces from port ' +
                                repr(puuid) + ': ' + r['error'])
                        if r['rows']:
                            interfaces = ovsdb.getlist(
                                r['rows'][0]['interfaces'])
                            with closing(
                                    self.apiroutine.executeAll([
                                        self._get_interface_info(
                                            connection, protocol, buuid,
                                            iuuid, puuid)
                                        for _, iuuid in interfaces
                                    ])) as g:
                                for m in g:
                                    yield m
                            # Each sub-result is a ``[port]`` or ``[]`` list;
                            # flatten them so this routine also returns a
                            # flat list of ports
                            self.apiroutine.retvalue = list(
                                itertools.chain.from_iterable(
                                    r[0] for r in self.apiroutine.retvalue))
                        else:
                            self.apiroutine.retvalue = []
                    except JsonRPCProtocolException:
                        self.apiroutine.retvalue = []
                    except ConnectionResetException:
                        self.apiroutine.retvalue = []

                for puuid in nset - oset:
                    self.ports_uuids[puuid] = buuid
                    if puuid in port_update and 'new' in port_update[puuid] \
                            and 'old' not in port_update[puuid]:
                        # Add all the interfaces in 'new'
                        interfaces = ovsdb.getlist(
                            port_update[puuid]['new']['interfaces'])
                        add_routine += [
                            self._get_interface_info(connection, protocol,
                                                     buuid, iuuid, puuid)
                            for _, iuuid in interfaces
                        ]
                    else:
                        add_routine.append(add_port_interfaces(puuid))
                # Execute the add_routine. ``add`` stays empty on any failure
                # (including the routine being closed) so that the removals
                # are still reported from the finally clause.
                add = []
                try:
                    with closing(
                            self.apiroutine.executeAll(add_routine)) as g:
                        for m in g:
                            yield m
                    # Flatten the per-routine port lists into one list of
                    # ports, the same shape as ``remove``
                    add = list(
                        itertools.chain.from_iterable(
                            r[0] for r in self.apiroutine.retvalue))
                finally:
                    if update:
                        self.scheduler.emergesend(
                            ModuleNotification(
                                self.getServiceName(), 'update',
                                datapathid=datapath_id,
                                connection=connection,
                                vhost=protocol.vhost,
                                add=add,
                                remove=remove,
                                reason='bridgemodify'
                                if 'old' in uo else 'bridgeup'))
            except JsonRPCProtocolException:
                pass
            except ConnectionResetException:
                pass
            except OVSDBBridgeNotAppearException:
                pass

        ignore_ports = set()
        for buuid, uo in bridge_update.items():
            # Bridge removals are ignored because we process OVSDBBridgeSetup event instead
            if 'old' in uo:
                if 'ports' in uo['old']:
                    oset = set((puuid for _, puuid in ovsdb.getlist(
                        uo['old']['ports'])))
                    ignore_ports.update(oset)
                if 'new' not in uo:
                    if buuid in self.bridge_datapathid:
                        del self.bridge_datapathid[buuid]
            if 'new' in uo:
                # If bridge contains this port is updated, we process the port update totally in bridge,
                # so we ignore it later
                if 'ports' in uo['new']:
                    nset = set((puuid for _, puuid in ovsdb.getlist(
                        uo['new']['ports'])))
                    ignore_ports.update(nset)
                working_routines.append(process_bridge(buuid, uo))

        def process_port(buuid, port_uuid, interfaces, remove_ids):
            if buuid not in self.bridge_datapathid:
                return
            datapath_id = self.bridge_datapathid[buuid]
            ports = self.managed_ports.get((protocol.vhost, datapath_id))
            remove = []
            if ports is not None:
                remove = [p for _, p in ports if p['_uuid'] in remove_ids]
                # NOTE(review): unlike _remove_interface, the removed
                # interfaces are not taken out of managed_ids here -- confirm
                # whether that is intended.
                ports[:] = [(puuid, p) for puuid, p in ports
                            if p['_uuid'] not in remove_ids]
            if interfaces:
                try:
                    with closing(
                            self.apiroutine.executeAll([
                                self._get_interface_info(
                                    connection, protocol, buuid, iuuid,
                                    port_uuid) for iuuid in interfaces
                            ])) as g:
                        for m in g:
                            yield m
                    # Flatten the ``[port]`` / ``[]`` results into a list of
                    # ports; empty results simply contribute nothing
                    add = list(
                        itertools.chain.from_iterable(
                            r[0] for r in self.apiroutine.retvalue))
                except Exception:
                    self._logger.warning("Cannot get new port information",
                                         exc_info=True)
                    add = []
            else:
                add = []
            if update:
                for m in self.apiroutine.waitForSend(
                        ModuleNotification(self.getServiceName(), 'update',
                                           datapathid=datapath_id,
                                           connection=connection,
                                           vhost=protocol.vhost,
                                           add=add,
                                           remove=remove,
                                           reason='bridgemodify')):
                    yield m

        for puuid, po in port_update.items():
            if puuid not in ignore_ports:
                bridge_id = self.ports_uuids.get(puuid)
                if bridge_id is not None:
                    # .get(): the bridge may already be gone; the None check
                    # below was unreachable with plain indexing
                    datapath_id = self.bridge_datapathid.get(bridge_id)
                    if datapath_id is not None:
                        # This port is modified
                        if 'new' in po:
                            nset = set((iuuid for _, iuuid in ovsdb.getlist(
                                po['new']['interfaces'])))
                        else:
                            nset = set()
                        if 'old' in po:
                            oset = set((iuuid for _, iuuid in ovsdb.getlist(
                                po['old']['interfaces'])))
                        else:
                            oset = set()
                        working_routines.append(
                            process_port(bridge_id, puuid, nset - oset,
                                         oset - nset))
        if update:
            for r in working_routines:
                self.apiroutine.subroutine(r)
        else:
            try:
                with closing(
                        self.apiroutine.executeAll(working_routines, None,
                                                   ())) as g:
                    for m in g:
                        yield m
            finally:
                self.scheduler.emergesend(
                    OVSDBConnectionPortsSynchronized(connection))

    def _get_ports(self, connection, protocol):
        '''
        Routine: install the 'ovsdb_port_manager_interfaces_monitor' monitor
        on *connection*, process its initial state, then process 'update'
        notifications until the connection state matcher fires.

        If installing the monitor fails, ``OVSDBConnectionPortsSynchronized``
        is still sent so that waiters are released, and the error is
        re-raised (``JsonRPCProtocolException`` is swallowed at the outer
        level).
        '''
        try:
            try:
                method, params = ovsdb.monitor(
                    'Open_vSwitch', 'ovsdb_port_manager_interfaces_monitor', {
                        'Bridge': [
                            ovsdb.monitor_request(
                                ["name", "datapath_id", "ports"])
                        ],
                        'Port': [ovsdb.monitor_request(["interfaces"])]
                    })
                try:
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                except JsonRPCErrorResultException:
                    # The monitor is already set, cancel it first
                    method2, params2 = ovsdb.monitor_cancel(
                        'ovsdb_port_manager_interfaces_monitor')
                    for m in protocol.querywithreply(method2, params2,
                                                     connection,
                                                     self.apiroutine, False):
                        yield m
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                r = self.apiroutine.jsonrpc_result
            except:
                # Deliberately bare: the synchronized event must be sent for
                # every way out, including the routine being closed
                def _msg():
                    for m in self.apiroutine.waitForSend(
                            OVSDBConnectionPortsSynchronized(connection)):
                        yield m
                self.apiroutine.subroutine(_msg(), False)
                raise
            # This is the initial state, it should contains all the ids of ports and interfaces
            self.apiroutine.subroutine(
                self._update_interfaces(connection, protocol, r, False))
            update_matcher = JsonRPCNotificationEvent.createMatcher(
                'update', connection, connection.connmark,
                _ismatch=lambda x: x.params[0] ==
                'ovsdb_port_manager_interfaces_monitor')
            conn_state = protocol.statematcher(connection)
            while True:
                yield (update_matcher, conn_state)
                if self.apiroutine.matcher is conn_state:
                    break
                else:
                    self.apiroutine.subroutine(
                        self._update_interfaces(
                            connection, protocol,
                            self.apiroutine.event.params[1], True))
        except JsonRPCProtocolException:
            pass
        finally:
            if self.apiroutine.currentroutine in self.monitor_routines:
                self.monitor_routines.remove(self.apiroutine.currentroutine)

    def _get_existing_ports(self):
        '''
        Routine: start a monitor routine for every connection reported by
        'ovsdbmanager', wait until each has sent
        ``OVSDBConnectionPortsSynchronized``, then mark this module
        synchronized and send the 'synchronized' notification.
        '''
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getallconnections',
                         {'vhost': None}):
            yield m
        matchers = []
        for c in self.apiroutine.retvalue:
            self.monitor_routines.add(
                self.apiroutine.subroutine(self._get_ports(c, c.protocol)))
            matchers.append(OVSDBConnectionPortsSynchronized.createMatcher(c))
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        '''
        Routine: if the initial load has not finished yet, wait for this
        service's 'synchronized' notification.
        '''
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'), )

    def _manage_ports(self):
        '''
        Main routine: start the initial load, start a monitor for every new
        OVSDB connection, and clean up the ports of a bridge on
        ``OVSDBBridgeSetup.DOWN``. On exit, all monitor routines are closed
        and 'unsynchronized' is sent.
        '''
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            connsetup = OVSDBConnectionSetup.createMatcher()
            bridgedown = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.DOWN)
            while True:
                yield (connsetup, bridgedown)
                e = self.apiroutine.event
                if self.apiroutine.matcher is connsetup:
                    self.monitor_routines.add(
                        self.apiroutine.subroutine(
                            self._get_ports(e.connection,
                                            e.connection.protocol)))
                else:
                    # Remove ports of the bridge
                    ports = self.managed_ports.get((e.vhost, e.datapathid))
                    if ports is not None:
                        ports_original = ports
                        ports = [p for _, p in ports]
                        for p in ports:
                            if p['id']:
                                self._remove_interface_id(
                                    e.connection, e.connection.protocol,
                                    e.datapathid, p)
                        newdpid = getattr(e, 'new_datapath_id', None)
                        buuid = e.bridgeuuid
                        if newdpid is not None:
                            # This bridge changes its datapath id
                            if buuid in self.bridge_datapathid and \
                                    self.bridge_datapathid[buuid] == e.datapathid:
                                self.bridge_datapathid[buuid] = newdpid

                            # Bind the current event data as defaults: the
                            # subroutine runs after this loop has moved on,
                            # when ``e`` may refer to a later event
                            def re_add_interfaces(
                                    e=e, buuid=buuid,
                                    ports_original=ports_original):
                                with closing(
                                        self.apiroutine.executeAll([
                                            self._get_interface_info(
                                                e.connection,
                                                e.connection.protocol, buuid,
                                                r['_uuid'], puuid)
                                            for puuid, r in ports_original
                                        ])) as g:
                                    for m in g:
                                        yield m
                                # Flatten the ``[port]`` / ``[]`` results
                                add = list(
                                    itertools.chain.from_iterable(
                                        r[0]
                                        for r in self.apiroutine.retvalue))
                                for m in self.apiroutine.waitForSend(
                                        ModuleNotification(
                                            self.getServiceName(), 'update',
                                            datapathid=e.datapathid,
                                            connection=e.connection,
                                            vhost=e.vhost,
                                            add=add,
                                            remove=[],
                                            reason='bridgeup')):
                                    yield m

                            self.apiroutine.subroutine(re_add_interfaces())
                        else:
                            # The ports are removed
                            for puuid, _ in ports_original:
                                # Membership must be tested on the mapping
                                # itself, not on the stored bridge uuid
                                if puuid in self.ports_uuids and \
                                        self.ports_uuids[puuid] == buuid:
                                    del self.ports_uuids[puuid]
                        del self.managed_ports[(e.vhost, e.datapathid)]
                        self.scheduler.emergesend(
                            ModuleNotification(self.getServiceName(),
                                               'update',
                                               datapathid=e.datapathid,
                                               connection=e.connection,
                                               vhost=e.vhost,
                                               add=[],
                                               remove=ports,
                                               reason='bridgedown'))
        finally:
            for r in list(self.monitor_routines):
                r.close()
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getports(self, datapathid, vhost=''):
        "Return all ports of a specified datapath (empty list if the datapath is not managed)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [
            p for _, p in self.managed_ports.get((vhost, datapathid), [])
        ]

    def getallports(self, vhost=None):
        "Return all ``(datapathid, port, vhost)`` tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = [
                (dpid, p, vh)
                for (vh, dpid), v in self.managed_ports.items()
                for _, p in v
            ]
        else:
            self.apiroutine.retvalue = [
                (dpid, p, vh)
                for (vh, dpid), v in self.managed_ports.items()
                if vh == vhost
                for _, p in v
            ]

    def getportbyno(self, datapathid, portno, vhost=''):
        "Return port with specified portno (masked to 16 bits), or None"
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost)

    def _getportbyno(self, datapathid, portno, vhost=''):
        '''
        Synchronous lookup by ``ofport``; returns None when the datapath or
        the port is unknown.
        '''
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        for _, p in ports:
            if p['ofport'] == portno:
                return p
        return None

    def waitportbyno(self, datapathid, portno, timeout=30, vhost=''):
        '''
        Wait for port with specified portno (masked to 16 bits)

        :raises OVSDBPortNotAppearException: the port did not appear within
            *timeout*
        '''
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m
        key = (vhost, datapathid, portno)

        def waitinner():
            p = self._getportbyno(datapathid, portno, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Register interest so _get_interface_info notifies us
                    self.wait_portnos[key] = self.wait_portnos.get(key, 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(
                        None, None, portno, None, vhost, datapathid), )
                except:
                    # Deliberately bare: a timeout closes this generator, and
                    # the registration must be undone in that case too
                    v = self.wait_portnos.get(key)
                    if v is not None:
                        if v <= 1:
                            del self.wait_portnos[key]
                        else:
                            self.wait_portnos[key] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException(
                'Port ' + repr(portno) + ' does not appear before timeout')

    def getportbyname(self, datapathid, name, vhost=''):
        "Return port with specified name (bytes names are decoded as UTF-8), or None"
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)

    def _getportbyname(self, datapathid, name, vhost=''):
        '''
        Synchronous lookup by interface name; returns None when the datapath
        or the port is unknown.
        '''
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        for _, p in ports:
            if p['name'] == name:
                return p
        return None

    def waitportbyname(self, datapathid, name, timeout=30, vhost=''):
        '''
        Wait for port with specified name (bytes names are decoded as UTF-8,
        as in getportbyname)

        :raises OVSDBPortNotAppearException: the port did not appear within
            *timeout*
        '''
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        for m in self._wait_for_sync():
            yield m
        key = (vhost, datapathid, name)

        def waitinner():
            p = self._getportbyname(datapathid, name, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Register interest so _get_interface_info notifies us.
                    # The counter lives in wait_names (it used to be seeded
                    # from wait_portnos by mistake).
                    self.wait_names[key] = self.wait_names.get(key, 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(
                        None, name, None, None, vhost, datapathid), )
                except:
                    # Deliberately bare: a timeout closes this generator, and
                    # the registration must be undone in that case too
                    v = self.wait_names.get(key)
                    if v is not None:
                        if v <= 1:
                            del self.wait_names[key]
                        else:
                            self.wait_names[key] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException(
                'Port ' + repr(name) + ' does not appear before timeout')

    def getportbyid(self, id, vhost=''):
        "Return port with the specified id. The return value is a pair: ``(datapath_id, port)``"
        for m in self._wait_for_sync():
            yield m
        # The result goes to retvalue; assigning to self.apiroutine itself
        # would replace the routine container
        self.apiroutine.retvalue = self._getportbyid(id, vhost)

    def _getportbyid(self, id, vhost=''):
        '''
        Synchronous lookup by id; returns the first ``(datapath_id, port)``
        pair registered for the id, or None.
        '''
        ports = self.managed_ids.get((vhost, id))
        if ports:
            return ports[0]
        return None

    def waitportbyid(self, id, timeout=30, vhost=''):
        '''
        Wait for port with the specified id. The return value is a pair ``(datapath_id, port)``

        :raises OVSDBPortNotAppearException: the port did not appear within
            *timeout*
        '''
        for m in self._wait_for_sync():
            yield m
        key = (vhost, id)

        def waitinner():
            p = self._getportbyid(id, vhost)
            if p is None:
                try:
                    # Register interest so _get_interface_info notifies us
                    self.wait_ids[key] = self.wait_ids.get(key, 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(
                        None, None, None, id, vhost), )
                except:
                    # Deliberately bare: a timeout closes this generator, and
                    # the registration must be undone in that case too
                    v = self.wait_ids.get(key)
                    if v is not None:
                        if v <= 1:
                            del self.wait_ids[key]
                        else:
                            self.wait_ids[key] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = (
                        self.apiroutine.event.datapathid,
                        self.apiroutine.event.port)
            else:
                self.apiroutine.retvalue = p

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException(
                'Port ' + repr(id) + ' does not appear before timeout')

    def resync(self, datapathid, vhost=''):
        '''
        Resync with current ports

        Implemented by reconnecting the OVSDB connection of the datapath and
        waiting for it to come back. The API return value is always None.
        '''
        # Sometimes when the OVSDB connection is very busy, monitor message may be dropped.
        # We must deal with this and recover from it
        # Save current managed_ports
        if (vhost, datapathid) not in self.managed_ports:
            self.apiroutine.retvalue = None
            return
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getconnection', {
                'datapathid': datapathid,
                'vhost': vhost
        }):
            yield m
        c = self.apiroutine.retvalue
        if c is not None:
            # For now, we restart the connection...
            for m in c.reconnect(False):
                yield m
            for m in self.apiroutine.waitWithTimeout(0.1):
                yield m
            for m in callAPI(self.apiroutine, 'ovsdbmanager',
                             'waitconnection', {
                                 'datapathid': datapathid,
                                 'vhost': vhost
                             }):
                yield m
        self.apiroutine.retvalue = None
class OpenflowManager(Module):
    '''
    Manage Openflow Connections

    Two indexes are maintained for the connections this module is told about:
    ``managed_conns`` maps ``(vhost, datapathid)`` to a list of connections and
    ``endpoint_conns`` maps ``(vhost, endpoint)`` to a list of connections.
    The module also gathers the table requests of every module registered with
    ``acquiretable`` and turns them into a per-vhost table numbering.
    '''
    service = True
    # None means connections from every vhost are managed; otherwise only
    # connections whose vhost is contained in this value
    # (presumably exposed as ``self.vhostbind`` by the configuration layer)
    _default_vhostbind = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> [connection, ...]
        self.managed_conns = {}
        # (vhost, endpoint) -> [connection, ...]
        self.endpoint_conns = {}
        # Names of the modules whose 'gettablerequest' API is queried
        self.table_modules = set()
        # True while an _acquire_tables routine is running
        self._acquiring = False
        # True when table_modules changed since the last calculation started
        self._acquire_updated = False
        # Last successful result: {vhost: (full_indices, tables)}; None until
        # the first successful calculation
        self._lastacquire = None
        self._synchronized = False
        self.createAPI(api(self.getconnections, self.apiroutine),
                       api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.acquiretable, self.apiroutine),
                       api(self.unacquiretable, self.apiroutine),
                       api(self.lastacquiredtables))

    def _add_connection(self, conn):
        """
        Register *conn* in both indexes.

        A connection already stored for the same datapath with the same
        auxiliary id is dropped from both indexes first.

        :return: list holding the replaced connection, or an empty list
        """
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault(
            (vhost, conn.openflow_datapathid), [])
        remove = []
        for i, ci in enumerate(conns):
            if ci.openflow_auxiliaryid == conn.openflow_auxiliaryid:
                remove = [ci]
                ep = _get_endpoint(ci)
                econns = self.endpoint_conns.get((vhost, ep))
                if econns is not None:
                    try:
                        econns.remove(ci)
                    except ValueError:
                        pass
                    if not econns:
                        del self.endpoint_conns[(vhost, ep)]
                # Safe although we are iterating: the loop ends right here
                del conns[i]
                break
        conns.append(conn)
        ep = _get_endpoint(conn)
        self.endpoint_conns.setdefault((vhost, ep), []).append(conn)
        # Only main connections (auxiliary id 0) are initialized here, and only
        # once a table layout is available
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return remove

    def _initialize_connection(self, conn):
        """
        Routine: delete all flows (and all groups when the protocol version
        defines ``ofp_group_mod``) on *conn*, install the default goto-table
        flows of the last acquired layout when the protocol version defines
        ``ofp_instruction_goto_table``, then send ``FlowInitialize``.
        """
        ofdef = conn.openflowdef
        flow_mod = ofdef.ofp_flow_mod(buffer_id=ofdef.OFP_NO_BUFFER,
                                      out_port=ofdef.OFPP_ANY,
                                      command=ofdef.OFPFC_DELETE)
        # Optional fields are filled only when the protocol definition has them
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            group_mod = ofdef.ofp_group_mod(command=ofdef.OFPGC_DELETE,
                                            group_id=ofdef.OFPG_ALL)
            cmds.append(group_mod)
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                # For every path, chain each table to the next table of the
                # same path with a priority 0 goto-table flow
                cmds = [ofdef.ofp_flow_mod(
                            table_id=t[i][1],
                            command=ofdef.OFPFC_ADD,
                            priority=0,
                            buffer_id=ofdef.OFP_NO_BUFFER,
                            out_port=ofdef.OFPP_ANY,
                            out_group=ofdef.OFPG_ANY,
                            match=ofdef.ofp_match_oxm(),
                            instructions=[
                                ofdef.ofp_instruction_goto_table(
                                    table_id=t[i + 1][1])
                            ])
                        for _, t in pathtable.items()
                        for i in range(0, len(t) - 1)]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        for m in self.apiroutine.waitForSend(
                FlowInitialize(conn, conn.openflow_datapathid,
                               conn.protocol.vhost)):
            yield m

    def _acquire_tables(self):
        """
        Routine: query 'gettablerequest' on every registered module and compute
        the table order for each vhost.

        Each request is used as ``(tables, vhosts)`` where *tables* is a list of
        ``(name, ancesters, pathname)`` and *vhosts* is None (all vhosts) or a
        collection of vhost names. For every vhost the tables are ordered so
        that a table follows all of its ancestors. Conflicting path names,
        circular dependencies or more than 255 tables abort the calculation.

        Sends ``TableAcquireUpdate`` with either ``result`` or ``exception``.
        Started by acquiretable/unacquiretable only after ``_acquire_updated``
        has been set, so the loop body runs at least once.
        """
        try:
            while self._acquire_updated:
                result = None
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(), )
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll(
                            (callAPI(self.apiroutine, module,
                                     'gettablerequest', {})
                             for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests
                                 if vhs is not None for vh in vhs)
                    vhost_result = {}
                    # Requests should be list of (name, (ancester, ancester, ...), pathname)
                    for vh in vhosts:
                        # name -> (set of ancestors, set of descendants)
                        graph = {}
                        # name -> pathname
                        table_path = {}
                        try:
                            for r in requests:
                                if r[1] is None or vh in r[1]:
                                    for name, ancesters, pathname in r[0]:
                                        if name in table_path:
                                            if table_path[name] != pathname:
                                                raise ValueError(
                                                    "table conflict detected: %r can not be in two path: %r and %r"
                                                    % (name, table_path[name],
                                                       pathname))
                                        else:
                                            table_path[name] = pathname
                                        if name not in graph:
                                            graph[name] = (set(ancesters),
                                                           set())
                                        else:
                                            graph[name][0].update(ancesters)
                                        for anc in ancesters:
                                            graph.setdefault(
                                                anc,
                                                (set(), set()))[1].add(name)
                        except ValueError as exc:
                            self._logger.error(str(exc))
                            exception = exc
                            break
                        else:
                            sequences = []

                            def dfs_sort(current):
                                # Emit current, then every descendant whose
                                # last unresolved ancestor was current
                                sequences.append(current)
                                for d in graph[current][1]:
                                    anc = graph[d][0]
                                    anc.remove(current)
                                    if not anc:
                                        dfs_sort(d)
                            # Start from tables without ancestors, ordered by
                            # (path name, table name). The key must use the
                            # lambda argument; the former key referenced the
                            # leftover loop variable ``name`` and was constant.
                            nopre_tables = sorted(
                                [k for k, v in graph.items() if not v[0]],
                                key=lambda x: (table_path.get(x, ''), x))
                            for t in nopre_tables:
                                dfs_sort(t)
                            if len(sequences) < len(graph):
                                # Tables never reached still wait for an
                                # ancestor, so they take part in a cycle
                                rest_tables = set(
                                    graph.keys()).difference(sequences)
                                self._logger.error(
                                    "Circle detected in table acquiring, following tables are related: %r, vhost = %r",
                                    sorted(rest_tables), vh)
                                self._logger.error(
                                    "Circle dependencies are: %s", ", ".join(
                                        repr(tuple(graph[t][0])) + "=>" + t
                                        for t in rest_tables))
                                exception = ValueError(
                                    "Circle detected in table acquiring, following tables are related: %r, vhost = %r"
                                    % (sorted(rest_tables), vh))
                                break
                            elif len(sequences) > 255:
                                self._logger.error(
                                    "Table limit exceeded: %d tables (only 255 allowed), vhost = %r",
                                    len(sequences), vh)
                                exception = ValueError(
                                    "Table limit exceeded: %d tables (only 255 allowed), vhost = %r"
                                    % (len(sequences), vh))
                                break
                            else:
                                # (name, index) in dependency order
                                full_indices = list(
                                    zip(sequences, itertools.count()))
                                # pathname -> tuple of (name, index), keeping
                                # the dependency order inside each path
                                tables = dict(
                                    (k, tuple(g))
                                    for k, g in itertools.groupby(
                                        sorted(full_indices,
                                               key=lambda x: table_path.get(
                                                   x[0], '')),
                                        lambda x: table_path.get(x[0], '')))
                                vhost_result[vh] = (full_indices, tables)
        finally:
            self._acquiring = False
        if exception:
            for m in self.apiroutine.waitForSend(
                    TableAcquireUpdate(exception=exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                # Layout changed: store it and re-initialize every connection
                self._lastacquire = result
                self._reinitall()
            for m in self.apiroutine.waitForSend(
                    TableAcquireUpdate(result=result)):
                yield m

    def load(self, container):
        """
        Create the sub queue used for ``TableAcquireDelayEvent``, send an
        initial ``TableAcquireUpdate(result=None)`` and run the base load.
        """
        self.scheduler.queue.addSubQueue(
            1, TableAcquireDelayEvent.createMatcher(),
            'ofpmanager_tableacquiredelay')
        for m in container.waitForSend(TableAcquireUpdate(result=None)):
            yield m
        for m in Module.load(self, container):
            yield m

    def unload(self, container, force=False):
        """
        Run the base unload, then remove the sub queue created by ``load``.
        """
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(
                syscall_removequeue(self.scheduler.queue,
                                    'ofpmanager_tableacquiredelay')):
            yield m

    def _reinitall(self):
        "Start _initialize_connection for every managed connection"
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))

    def _manage_existing(self):
        """
        Routine: register the connections reported by the 'openflowserver'
        module, then mark the manager as synchronized and notify waiters.
        """
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections",
                         {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Routine: wait for the 'synchronized' notification unless already synchronized"
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'), )

    def _manage_conns(self):
        """
        Main routine: follow connection setup / down events and keep both
        indexes up to date, sending an 'update' ``ModuleNotification`` for
        every change. An 'unsynchronized' notification is sent when the
        routine ends.
        """
        vb = self.vhostbind
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                # Only events created by a protocol of a bound vhost
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP,
                    _ismatch=lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN,
                    _ismatch=lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                e = self.apiroutine.event
                if self.apiroutine.matcher is conn_up:
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(
                        ModuleNotification(self.getServiceName(), 'update',
                                           add=[e.connection], remove=remove))
                else:
                    vhost = e.createby.vhost
                    conns = self.managed_conns.get((vhost, e.datapathid))
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            # Not (or no longer) managed, nothing to report
                            pass
                        else:
                            remove.append(e.connection)
                        if not conns:
                            del self.managed_conns[(vhost, e.datapathid)]
                        # Also delete from endpoint_conns
                        ep = _get_endpoint(e.connection)
                        econns = self.endpoint_conns.get((vhost, ep))
                        if econns is not None:
                            try:
                                econns.remove(e.connection)
                            except ValueError:
                                pass
                            if not econns:
                                del self.endpoint_conns[(vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(
                            ModuleNotification(self.getServiceName(), 'update',
                                               add=[], remove=remove))
        finally:
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getconnections(self, datapathid, vhost=''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        # A copy, so callers cannot modify the index
        self.apiroutine.retvalue = list(
            self.managed_conns.get((vhost, datapathid), []))

    def getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid,
                                                       vhost)

    def _getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        """
        Synchronous lookup.

        :return: the managed connection of ``(vhost, datapathid)`` with the
                 given auxiliary id, or None when there is none
        """
        for c in self.managed_conns.get((vhost, datapathid), ()):
            if c.openflow_auxiliaryid == auxiliaryid:
                return c
        return None

    def waitconnection(self, datapathid, auxiliaryid=0, timeout=30, vhost=''):
        """
        Wait for a datapath connection

        Returns the connection at once when it is already managed; otherwise
        waits up to *timeout* for a matching CONNECTION_SETUP event.

        :raises ConnectionResetException: when the wait times out
        """
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is not None:
            self.apiroutine.retvalue = c
            return
        for m in self.apiroutine.waitWithTimeout(
                timeout,
                OpenflowConnectionStateEvent.createMatcher(
                    datapathid, auxiliaryid,
                    OpenflowConnectionStateEvent.CONNECTION_SETUP,
                    _ismatch=lambda x: x.createby.vhost == vhost)):
            yield m
        if self.apiroutine.timeout:
            raise ConnectionResetException(
                'Datapath %016x is not connected' % datapathid)
        self.apiroutine.retvalue = self.apiroutine.event.connection

    def getdatapathids(self, vhost=''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [
            k[1] for k in self.managed_conns if k[0] == vhost
        ]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return ``(vhost, datapathid)`` pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost=''):
        """
        Get all connections from vhost. If vhost is None, return all connections from any host

        The return value is a list with one item per managed datapath; each
        item is that datapath's connection list.
        """
        for m in self._wait_for_sync():
            yield m
        # NOTE(review): the previous code wrapped these in a single-argument
        # itertools.chain(), which does not flatten. The per-datapath shape is
        # kept because callers may rely on it; switch to
        # itertools.chain.from_iterable() if a flat list was intended.
        if vhost is None:
            self.apiroutine.retvalue = list(self.managed_conns.values())
        else:
            self.apiroutine.retvalue = [
                v for k, v in self.managed_conns.items() if k[0] == vhost
            ]

    def getconnectionsbyendpoint(self, endpoint, vhost=''):
        """
        Get connection by endpoint address (IP, IPv6 or UNIX socket address)

        Returns the list stored for ``(vhost, endpoint)`` or None.
        """
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost='', timeout=30):
        """
        Get connection by endpoint name (Domain name, IP or IPv6 address)

        An empty *name* looks up the endpoint ``''`` directly. Otherwise the
        name is resolved and the resolved addresses are tried in order until
        one of them has connections; the result is None when none has.

        :raises IOError: when resolving times out or reports an error
        """
        # Resolve the name
        if not name:
            # Fixed: this used to call the non-existent
            # ``getconnectionbyendpoint`` and raised AttributeError
            for m in self.getconnectionsbyendpoint('', vhost):
                yield m
            return
        request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM,
                   socket.IPPROTO_TCP,
                   socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
        # Resolve hostname
        for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
            yield m
        for m in self.apiroutine.waitWithTimeout(
                timeout, ResolveResponseEvent.createMatcher(request)):
            yield m
        if self.apiroutine.timeout:
            # Resolve is only allowed through asynchronous resolver
            raise IOError('Resolve hostname timeout: ' + name)
        if hasattr(self.apiroutine.event, 'error'):
            raise IOError('Cannot resolve hostname: ' + name)
        resp = self.apiroutine.event.response
        # Make the result well defined when the response holds no address
        self.apiroutine.retvalue = None
        for r in resp:
            raddr = r[4]
            if isinstance(raddr, tuple):
                # Ignore port
                endpoint = raddr[0]
            else:
                # Unix socket? This should not happen, but in case...
                endpoint = raddr
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
            if self.apiroutine.retvalue is not None:
                break

    def getendpoints(self, vhost=''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [
            k[1] for k in self.endpoint_conns if k[0] == vhost
        ]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return ``(vhost, endpoint)`` pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def lastacquiredtables(self, vhost=""):
        """
        Get acquired table IDs

        Plain function (registered without a routine container): returns the
        ``(full_indices, tables)`` entry of *vhost* from the last successful
        calculation, or None when there is no entry or nothing has been
        calculated yet.
        """
        if self._lastacquire is None:
            # No layout calculated yet; avoid AttributeError on None
            return None
        return self._lastacquire.get(vhost)

    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if modulename not in self.table_modules:
            self.table_modules.add(modulename)
            self._acquire_updated = True
            # A running calculation notices the flag and loops again
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Never executed; only makes this function a generator
        if False:
            yield

    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.remove(modulename)
            self._acquire_updated = True
            # A running calculation notices the flag and loops again
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Never executed; only makes this function a generator
        if False:
            yield
class OpenflowManager(Module):
    '''
    Manage Openflow Connections

    Two indexes are maintained for the connections this module is told about:
    ``managed_conns`` maps ``(vhost, datapathid)`` to a list of connections and
    ``endpoint_conns`` maps ``(vhost, endpoint)`` to a list of connections.
    The module also gathers the table requests of every module registered with
    ``acquiretable`` and turns them into a per-vhost table numbering.
    '''
    service = True
    # None means connections from every vhost are managed; otherwise only
    # connections whose vhost is contained in this value
    # (presumably exposed as ``self.vhostbind`` by the configuration layer)
    _default_vhostbind = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> [connection, ...]
        self.managed_conns = {}
        # (vhost, endpoint) -> [connection, ...]
        self.endpoint_conns = {}
        # Names of the modules whose 'gettablerequest' API is queried
        self.table_modules = set()
        # True while an _acquire_tables routine is running
        self._acquiring = False
        # True when table_modules changed since the last calculation started
        self._acquire_updated = False
        # Last successful result: {vhost: (full_indices, tables)}; None until
        # the first successful calculation
        self._lastacquire = None
        self._synchronized = False
        self.createAPI(api(self.getconnections, self.apiroutine),
                       api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.acquiretable, self.apiroutine),
                       api(self.unacquiretable, self.apiroutine),
                       api(self.lastacquiredtables)
                       )

    def _add_connection(self, conn):
        """
        Register *conn* in both indexes.

        A connection already stored for the same datapath with the same
        auxiliary id is dropped from both indexes first.

        :return: list holding the replaced connection, or an empty list
        """
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault((vhost, conn.openflow_datapathid), [])
        remove = []
        for i, ci in enumerate(conns):
            if ci.openflow_auxiliaryid == conn.openflow_auxiliaryid:
                remove = [ci]
                ep = _get_endpoint(ci)
                econns = self.endpoint_conns.get((vhost, ep))
                if econns is not None:
                    try:
                        econns.remove(ci)
                    except ValueError:
                        pass
                    if not econns:
                        del self.endpoint_conns[(vhost, ep)]
                # Safe although we are iterating: the loop ends right here
                del conns[i]
                break
        conns.append(conn)
        ep = _get_endpoint(conn)
        self.endpoint_conns.setdefault((vhost, ep), []).append(conn)
        # Only main connections (auxiliary id 0) are initialized here, and only
        # once a table layout is available
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return remove

    def _initialize_connection(self, conn):
        """
        Routine: delete all flows (and all groups when the protocol version
        defines ``ofp_group_mod``) on *conn*, install the default goto-table
        flows of the last acquired layout when the protocol version defines
        ``ofp_instruction_goto_table``, then send ``FlowInitialize``.
        """
        ofdef = conn.openflowdef
        flow_mod = ofdef.ofp_flow_mod(buffer_id = ofdef.OFP_NO_BUFFER,
                                      out_port = ofdef.OFPP_ANY,
                                      command = ofdef.OFPFC_DELETE
                                      )
        # Optional fields are filled only when the protocol definition has them
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            group_mod = ofdef.ofp_group_mod(command = ofdef.OFPGC_DELETE,
                                            group_id = ofdef.OFPG_ALL
                                            )
            cmds.append(group_mod)
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                # For every path, chain each table to the next table of the
                # same path with a priority 0 goto-table flow
                cmds = [ofdef.ofp_flow_mod(table_id = t[i][1],
                                           command = ofdef.OFPFC_ADD,
                                           priority = 0,
                                           buffer_id = ofdef.OFP_NO_BUFFER,
                                           out_port = ofdef.OFPP_ANY,
                                           out_group = ofdef.OFPG_ANY,
                                           match = ofdef.ofp_match_oxm(),
                                           instructions = [ofdef.ofp_instruction_goto_table(table_id = t[i+1][1])]
                                           )
                        for _, t in pathtable.items()
                        for i in range(0, len(t) - 1)]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        for m in self.apiroutine.waitForSend(FlowInitialize(conn, conn.openflow_datapathid, conn.protocol.vhost)):
            yield m

    def _acquire_tables(self):
        """
        Routine: query 'gettablerequest' on every registered module and compute
        the table order for each vhost.

        Each request is used as ``(tables, vhosts)`` where *tables* is a list of
        ``(name, ancesters, pathname)`` and *vhosts* is None (all vhosts) or a
        collection of vhost names. For every vhost the tables are ordered so
        that a table follows all of its ancestors. Conflicting path names,
        circular dependencies or more than 255 tables abort the calculation.

        Sends ``TableAcquireUpdate`` with either ``result`` or ``exception``.
        Started by acquiretable/unacquiretable only after ``_acquire_updated``
        has been set, so the loop body runs at least once.
        """
        try:
            while self._acquire_updated:
                result = None
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(),)
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll((callAPI(self.apiroutine, module, 'gettablerequest', {}) for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests if vhs is not None for vh in vhs)
                    vhost_result = {}
                    # Requests should be list of (name, (ancester, ancester, ...), pathname)
                    for vh in vhosts:
                        # name -> (set of ancestors, set of descendants)
                        graph = {}
                        # name -> pathname
                        table_path = {}
                        try:
                            for r in requests:
                                if r[1] is None or vh in r[1]:
                                    for name, ancesters, pathname in r[0]:
                                        if name in table_path:
                                            if table_path[name] != pathname:
                                                raise ValueError("table conflict detected: %r can not be in two path: %r and %r" % (name, table_path[name], pathname))
                                        else:
                                            table_path[name] = pathname
                                        if name not in graph:
                                            graph[name] = (set(ancesters), set())
                                        else:
                                            graph[name][0].update(ancesters)
                                        for anc in ancesters:
                                            graph.setdefault(anc, (set(), set()))[1].add(name)
                        except ValueError as exc:
                            self._logger.error(str(exc))
                            exception = exc
                            break
                        else:
                            sequences = []

                            def dfs_sort(current):
                                # Emit current, then every descendant whose
                                # last unresolved ancestor was current
                                sequences.append(current)
                                for d in graph[current][1]:
                                    anc = graph[d][0]
                                    anc.remove(current)
                                    if not anc:
                                        dfs_sort(d)
                            # Start from tables without ancestors, ordered by
                            # (path name, table name). The key must use the
                            # lambda argument; the former key referenced the
                            # leftover loop variable ``name`` and was constant.
                            nopre_tables = sorted([k for k, v in graph.items() if not v[0]],
                                                  key = lambda x: (table_path.get(x, ''), x))
                            for t in nopre_tables:
                                dfs_sort(t)
                            if len(sequences) < len(graph):
                                # Tables never reached still wait for an
                                # ancestor, so they take part in a cycle
                                rest_tables = set(graph.keys()).difference(sequences)
                                self._logger.error("Circle detected in table acquiring, following tables are related: %r, vhost = %r", sorted(rest_tables), vh)
                                self._logger.error("Circle dependencies are: %s", ", ".join(repr(tuple(graph[t][0])) + "=>" + t for t in rest_tables))
                                exception = ValueError("Circle detected in table acquiring, following tables are related: %r, vhost = %r" % (sorted(rest_tables),vh))
                                break
                            elif len(sequences) > 255:
                                self._logger.error("Table limit exceeded: %d tables (only 255 allowed), vhost = %r", len(sequences), vh)
                                exception = ValueError("Table limit exceeded: %d tables (only 255 allowed), vhost = %r" % (len(sequences),vh))
                                break
                            else:
                                # (name, index) in dependency order
                                full_indices = list(zip(sequences, itertools.count()))
                                # pathname -> tuple of (name, index), keeping
                                # the dependency order inside each path
                                tables = dict((k, tuple(g)) for k, g in
                                              itertools.groupby(sorted(full_indices, key = lambda x: table_path.get(x[0], '')),
                                                                lambda x: table_path.get(x[0], '')))
                                vhost_result[vh] = (full_indices, tables)
        finally:
            self._acquiring = False
        if exception:
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(exception = exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                # Layout changed: store it and re-initialize every connection
                self._lastacquire = result
                self._reinitall()
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(result = result)):
                yield m

    def load(self, container):
        """
        Create the sub queue used for ``TableAcquireDelayEvent``, send an
        initial ``TableAcquireUpdate(result = None)`` and run the base load.
        """
        self.scheduler.queue.addSubQueue(1, TableAcquireDelayEvent.createMatcher(), 'ofpmanager_tableacquiredelay')
        for m in container.waitForSend(TableAcquireUpdate(result = None)):
            yield m
        for m in Module.load(self, container):
            yield m

    def unload(self, container, force=False):
        """
        Run the base unload, then remove the sub queue created by ``load``.
        """
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(syscall_removequeue(self.scheduler.queue, 'ofpmanager_tableacquiredelay')):
            yield m

    def _reinitall(self):
        "Start _initialize_connection for every managed connection"
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))

    def _manage_existing(self):
        """
        Routine: register the connections reported by the 'openflowserver'
        module, then mark the manager as synchronized and notify waiters.
        """
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Routine: wait for the 'synchronized' notification unless already synchronized"
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)

    def _manage_conns(self):
        """
        Main routine: follow connection setup / down events and keep both
        indexes up to date, sending an 'update' ``ModuleNotification`` for
        every change. An 'unsynchronized' notification is sent when the
        routine ends.
        """
        vb = self.vhostbind
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                # Only events created by a protocol of a bound vhost
                conn_up = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_SETUP,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_DOWN,
                                                                       _ismatch = lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                e = self.apiroutine.event
                if self.apiroutine.matcher is conn_up:
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update', add = [e.connection], remove = remove))
                else:
                    vhost = e.createby.vhost
                    conns = self.managed_conns.get((vhost, e.datapathid))
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            # Not (or no longer) managed, nothing to report
                            pass
                        else:
                            remove.append(e.connection)
                        if not conns:
                            del self.managed_conns[(vhost, e.datapathid)]
                        # Also delete from endpoint_conns
                        ep = _get_endpoint(e.connection)
                        econns = self.endpoint_conns.get((vhost, ep))
                        if econns is not None:
                            try:
                                econns.remove(e.connection)
                            except ValueError:
                                pass
                            if not econns:
                                del self.endpoint_conns[(vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update', add = [], remove = remove))
        finally:
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getconnections(self, datapathid, vhost = ''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        # A copy, so callers cannot modify the index
        self.apiroutine.retvalue = list(self.managed_conns.get((vhost, datapathid), []))

    def getconnection(self, datapathid, auxiliaryid = 0, vhost = ''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid, vhost)

    def _getconnection(self, datapathid, auxiliaryid = 0, vhost = ''):
        """
        Synchronous lookup.

        :return: the managed connection of ``(vhost, datapathid)`` with the
                 given auxiliary id, or None when there is none
        """
        for c in self.managed_conns.get((vhost, datapathid), ()):
            if c.openflow_auxiliaryid == auxiliaryid:
                return c
        return None

    def waitconnection(self, datapathid, auxiliaryid = 0, timeout = 30, vhost = ''):
        """
        Wait for a datapath connection

        Returns the connection at once when it is already managed; otherwise
        waits up to *timeout* for a matching CONNECTION_SETUP event.

        :raises ConnectionResetException: when the wait times out
        """
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is not None:
            self.apiroutine.retvalue = c
            return
        for m in self.apiroutine.waitWithTimeout(timeout,
                        OpenflowConnectionStateEvent.createMatcher(datapathid, auxiliaryid,
                                OpenflowConnectionStateEvent.CONNECTION_SETUP,
                                _ismatch = lambda x: x.createby.vhost == vhost)):
            yield m
        if self.apiroutine.timeout:
            raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
        self.apiroutine.retvalue = self.apiroutine.event.connection

    def getdatapathids(self, vhost = ''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns if k[0] == vhost]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost = ''):
        """
        Get all connections from vhost. If vhost is None, return all connections from any host

        The return value is a list with one item per managed datapath; each
        item is that datapath's connection list.
        """
        for m in self._wait_for_sync():
            yield m
        # NOTE(review): the previous code wrapped these in a single-argument
        # itertools.chain(), which does not flatten. The per-datapath shape is
        # kept because callers may rely on it; switch to
        # itertools.chain.from_iterable() if a flat list was intended.
        if vhost is None:
            self.apiroutine.retvalue = list(self.managed_conns.values())
        else:
            self.apiroutine.retvalue = [v for k, v in self.managed_conns.items() if k[0] == vhost]

    def getconnectionsbyendpoint(self, endpoint, vhost = ''):
        """
        Get connection by endpoint address (IP, IPv6 or UNIX socket address)

        Returns the list stored for ``(vhost, endpoint)`` or None.
        """
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost = '', timeout = 30):
        """
        Get connection by endpoint name (Domain name, IP or IPv6 address)

        An empty *name* looks up the endpoint ``''`` directly. Otherwise the
        name is resolved and the resolved addresses are tried in order until
        one of them has connections; the result is None when none has.

        :raises IOError: when resolving times out or reports an error
        """
        # Resolve the name
        if not name:
            # Fixed: this used to call the non-existent
            # ``getconnectionbyendpoint`` and raised AttributeError
            for m in self.getconnectionsbyendpoint('', vhost):
                yield m
            return
        request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
        # Resolve hostname
        for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
            yield m
        for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
            yield m
        if self.apiroutine.timeout:
            # Resolve is only allowed through asynchronous resolver
            raise IOError('Resolve hostname timeout: ' + name)
        if hasattr(self.apiroutine.event, 'error'):
            raise IOError('Cannot resolve hostname: ' + name)
        resp = self.apiroutine.event.response
        # Make the result well defined when the response holds no address
        self.apiroutine.retvalue = None
        for r in resp:
            raddr = r[4]
            if isinstance(raddr, tuple):
                # Ignore port
                endpoint = raddr[0]
            else:
                # Unix socket? This should not happen, but in case...
                endpoint = raddr
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
            if self.apiroutine.retvalue is not None:
                break

    def getendpoints(self, vhost = ''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def lastacquiredtables(self, vhost = ""):
        """
        Get acquired table IDs

        Plain function (registered without a routine container): returns the
        ``(full_indices, tables)`` entry of *vhost* from the last successful
        calculation, or None when there is no entry or nothing has been
        calculated yet.
        """
        if self._lastacquire is None:
            # No layout calculated yet; avoid AttributeError on None
            return None
        return self._lastacquire.get(vhost)

    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if modulename not in self.table_modules:
            self.table_modules.add(modulename)
            self._acquire_updated = True
            # A running calculation notices the flag and loops again
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Never executed; only makes this function a generator
        if False:
            yield

    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.remove(modulename)
            self._acquire_updated = True
            # A running calculation notices the flag and loops again
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Never executed; only makes this function a generator
        if False:
            yield
class TestObjectDB(Module):
    '''
    Test module which exercises the ``objectdb`` service: it creates physical/logical
    networks and ports through transactions, and watches the port sets to mark new
    ports as owned by this instance.
    '''
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._main
        self.routines.append(self.apiroutine)
        # Counter used to build a unique request id for every objectdb query
        self._reqid = 0
        # Identifier written into the ``owner`` attribute of ports managed by this instance
        self._ownerid = uuid1().hex
        self.createAPI(api(self.createlogicalnetwork, self.apiroutine),
                       api(self.createlogicalnetworks, self.apiroutine),
                       api(self.createphysicalnetwork, self.apiroutine),
                       api(self.createphysicalnetworks, self.apiroutine),
                       api(self.createphysicalport, self.apiroutine),
                       api(self.createphysicalports, self.apiroutine),
                       api(self.createlogicalport, self.apiroutine),
                       api(self.createlogicalports, self.apiroutine),
                       api(self.getlogicalnetworks, self.apiroutine))
        self._logger.setLevel(logging.DEBUG)

    def _monitor(self):
        "Log every DataObjectUpdateEvent forever"
        update_event = DataObjectUpdateEvent.createMatcher()
        while True:
            yield (update_event,)
            self._logger.info('Database update: %r', self.apiroutine.event)

    def _dumpkeys(self, keys):
        "Retrieve the objects for *keys* with ``mget`` and store their dumped form in retvalue"
        self._reqid += 1
        reqid = ('testobjectdb', self._reqid)
        for m in callAPI(self.apiroutine, 'objectdb', 'mget', {'keys': keys, 'requestid': reqid}):
            yield m
        retobjs = self.apiroutine.retvalue
        with watch_context(keys, retobjs, reqid, self.apiroutine):
            self.apiroutine.retvalue = [dump(v) for v in retobjs]

    def _updateport(self, key):
        '''
        Try to take ownership of the port stored at *key*, then log every later update
        until the port is deleted. The whole procedure is interrupted when the
        ModuleLoadStateChanged UNLOADING event for ``self.target`` is matched.
        '''
        unload_matcher = ModuleLoadStateChanged.createMatcher(self.target, ModuleLoadStateChanged.UNLOADING)

        def updateinner():
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            for m in callAPI(self.apiroutine, 'objectdb', 'get', {'key': key, 'requestid': reqid}):
                yield m
            portobj = self.apiroutine.retvalue
            with watch_context([key], [portobj], reqid, self.apiroutine):
                if portobj is not None:
                    @updater
                    def write_status(portobj):
                        if portobj is None:
                            raise ValueError('Already deleted')
                        if not hasattr(portobj, 'owner'):
                            portobj.owner = self._ownerid
                            portobj.status = 'READY'
                            return [portobj]
                        else:
                            raise ValueError('Already managed')
                    try:
                        for m in callAPI(self.apiroutine, 'objectdb', 'transact',
                                         {'keys': [portobj.getkey()], 'updater': write_status}):
                            yield m
                    except ValueError:
                        # The port is deleted or owned by someone else: nothing to manage
                        pass
                    else:
                        for m in portobj.waitif(self.apiroutine, lambda x: x.isdeleted() or hasattr(x, 'owner')):
                            yield m
                        self._logger.info('Port managed: %r', dump(portobj))
                        while True:
                            for m in portobj.waitif(self.apiroutine, lambda x: True, True):
                                yield m
                            if portobj.isdeleted():
                                self._logger.info('Port deleted: %r', dump(portobj))
                                break
                            else:
                                self._logger.info('Port updated: %r', dump(portobj))
        try:
            for m in self.apiroutine.withException(updateinner(), unload_matcher):
                yield m
        except RoutineException:
            pass

    def _waitforchange(self, key):
        "Watch the set object at *key* and start ``_updateport`` for every newly added member"
        for m in callAPI(self.apiroutine, 'objectdb', 'watch', {'key': key, 'requestid': 'testobjectdb'}):
            yield m
        setobj = self.apiroutine.retvalue
        with watch_context([key], [setobj], 'testobjectdb', self.apiroutine):
            for m in setobj.wait(self.apiroutine):
                yield m
            oldset = set()
            while True:
                # Only members which were not seen in the previous round are processed
                for weakref in setobj.set.dataset().difference(oldset):
                    self.apiroutine.subroutine(self._updateport(weakref.getkey()))
                oldset = set(setobj.set.dataset())
                for m in setobj.waitif(self.apiroutine, lambda x: not x.isdeleted(), True):
                    yield m

    def _main(self):
        "Main routine: run the update monitor and the watchers for both port sets"
        routines = [self._monitor()]
        keys = [LogicalPortSet.default_key(), PhysicalPortSet.default_key()]
        for k in keys:
            routines.append(self._waitforchange(k))
        for m in self.apiroutine.executeAll(routines, retnames = ()):
            yield m

    def load(self, container):
        "Create the four top-level set objects if they do not exist, then load the module"
        @updater
        def initialize(phynetset, lognetset, logportset, phyportset):
            if phynetset is None:
                phynetset = PhysicalNetworkSet()
                phynetset.set = DataObjectSet()
            if lognetset is None:
                lognetset = LogicalNetworkSet()
                lognetset.set = DataObjectSet()
            if logportset is None:
                logportset = LogicalPortSet()
                logportset.set = DataObjectSet()
            if phyportset is None:
                phyportset = PhysicalPortSet()
                phyportset.set = DataObjectSet()
            return [phynetset, lognetset, logportset, phyportset]
        for m in callAPI(container, 'objectdb', 'transact',
                         {'keys': [PhysicalNetworkSet.default_key(),
                                   LogicalNetworkSet.default_key(),
                                   LogicalPortSet.default_key(),
                                   PhysicalPortSet.default_key()],
                          'updater': initialize}):
            yield m
        for m in Module.load(self, container):
            yield m

    def createphysicalnetwork(self, type = 'vlan', id = None, **kwargs):
        "Create a single physical network and return its dumped form"
        new_network, new_map = self._createphysicalnetwork(type, id, **kwargs)
        @updater
        def create_phy(physet, phynet, phymap):
            phynet = set_new(phynet, new_network)
            phymap = set_new(phymap, new_map)
            physet.set.dataset().add(phynet.create_weakreference())
            return [physet, phynet, phymap]
        for m in callAPI(self.apiroutine, 'objectdb', 'transact',
                         {'keys': [PhysicalNetworkSet.default_key(), new_network.getkey(), new_map.getkey()],
                          'updater': create_phy}):
            yield m
        for m in self._dumpkeys([new_network.getkey()]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]

    def createphysicalnetworks(self, networks):
        "Create physical networks from a list of keyword dictionaries; return their dumped forms"
        new_networks = [self._createphysicalnetwork(**n) for n in networks]
        @updater
        def create_phys(physet, *phynets):
            # phynets is laid out as (network, map, network, map, ...) in the order of new_networks
            return_nets = [None, None] * len(new_networks)
            for i in range(0, len(new_networks)):
                return_nets[i * 2] = set_new(phynets[i * 2], new_networks[i][0])
                return_nets[i * 2 + 1] = set_new(phynets[i * 2 + 1], new_networks[i][1])
                physet.set.dataset().add(new_networks[i][0].create_weakreference())
            return [physet] + return_nets
        keys = [sn.getkey() for n in new_networks for sn in n]
        for m in callAPI(self.apiroutine, 'objectdb', 'transact',
                         {'keys': [PhysicalNetworkSet.default_key()] + keys, 'updater': create_phys}):
            yield m
        for m in self._dumpkeys([n[0].getkey() for n in new_networks]):
            yield m

    def _createlogicalnetwork(self, physicalnetwork, id = None, **kwargs):
        "Build (but do not store) a LogicalNetwork and its LogicalNetworkMap; return them as a tuple"
        if not id:
            id = str(uuid1())
        new_network = LogicalNetwork.create_instance(id)
        for k, v in kwargs.items():
            setattr(new_network, k, v)
        new_network.physicalnetwork = ReferenceObject(PhysicalNetwork.default_key(physicalnetwork))
        new_networkmap = LogicalNetworkMap.create_instance(id)
        new_networkmap.network = new_network.create_reference()
        return new_network, new_networkmap

    def createlogicalnetworks(self, networks):
        '''
        Create logical networks from a list of keyword dictionaries and return their dumped forms.

        For a physical network of type 'vlan' a specified ``vlanid`` is validated against the
        allocation table and the VLAN ranges, otherwise the first free VLAN id is allocated.
        Any other physical network can only hold one logical network. The updater raises
        ValueError when the physical network is missing or the allocation is not possible.
        '''
        new_networks = [self._createlogicalnetwork(**n) for n in networks]
        physical_networks = list(set(n[0].physicalnetwork.getkey() for n in new_networks))
        physical_maps = [PhysicalNetworkMap.default_key(PhysicalNetwork._getIndices(k)[1][0])
                         for k in physical_networks]
        @updater
        def create_logs(logset, *networks):
            # Argument layout follows the 'keys' list below:
            # (network, map) pairs, then physical network maps, then physical networks
            phy_maps = list(networks[len(new_networks) * 2 : len(new_networks) * 2 + len(physical_networks)])
            phy_nets = list(networks[len(new_networks) * 2 + len(physical_networks):])
            phy_dict = dict(zip(physical_networks, zip(phy_nets, phy_maps)))
            return_nets = [None, None] * len(new_networks)
            for i in range(0, len(new_networks)):
                return_nets[2 * i] = set_new(networks[2 * i], new_networks[i][0])
                return_nets[2 * i + 1] = set_new(networks[2 * i + 1], new_networks[i][1])
            for n in return_nets[::2]:
                phynet, phymap = phy_dict.get(n.physicalnetwork.getkey())
                if phynet is None:
                    _, (phyid,) = PhysicalNetwork._getIndices(n.physicalnetwork.getkey())
                    raise ValueError('Physical network %r does not exist' % (phyid,))
                else:
                    if phynet.type == 'vlan':
                        if hasattr(n, 'vlanid'):
                            n.vlanid = int(n.vlanid)
                            if n.vlanid <= 0 or n.vlanid >= 4095:
                                raise ValueError('Invalid VLAN ID')
                            # VLAN id is specified
                            if str(n.vlanid) in phymap.network_allocation:
                                raise ValueError('VLAN ID %r is already allocated in physical network %r' % (n.vlanid,phynet.id))
                            else:
                                for start, end in phynet.vlanrange:
                                    if start <= n.vlanid <= end:
                                        break
                                else:
                                    raise ValueError('VLAN ID %r is not in vlan range of physical network %r' % (n.vlanid,phynet.id))
                            phymap.network_allocation[str(n.vlanid)] = n.create_weakreference()
                        else:
                            # Allocate a new VLAN id: take the first id in any range which is not allocated
                            for start, end in phynet.vlanrange:
                                for vlanid in range(start, end + 1):
                                    if str(vlanid) not in phymap.network_allocation:
                                        break
                                else:
                                    # This range is exhausted, try the next one
                                    continue
                                break
                            else:
                                raise ValueError('Not enough VLAN ID to be allocated in physical network %r' % (phynet.id,))
                            n.vlanid = vlanid
                            phymap.network_allocation[str(vlanid)] = n.create_weakreference()
                    else:
                        if phymap.network_allocation:
                            # The message must be formatted with %, not passed as a second exception argument
                            raise ValueError('Physical network %r is already allocated by another logical network' % (phynet.id,))
                        phymap.network_allocation['native'] = n.create_weakreference()
                    phymap.networks.dataset().add(n.create_weakreference())
                logset.set.dataset().add(n.create_weakreference())
            return [logset] + return_nets + phy_maps
        for m in callAPI(self.apiroutine, 'objectdb', 'transact',
                         {'keys': [LogicalNetworkSet.default_key()] +
                                  [sn.getkey() for n in new_networks for sn in n] +
                                  physical_maps +
                                  physical_networks,
                          'updater': create_logs}):
            yield m
        for m in self._dumpkeys([n[0].getkey() for n in new_networks]):
            yield m

    def createlogicalnetwork(self, physicalnetwork, id = None, **kwargs):
        "Create a single logical network and return its dumped form"
        n = {'physicalnetwork': physicalnetwork, 'id': id}
        n.update(kwargs)
        for m in self.createlogicalnetworks([n]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]

    def _createphysicalnetwork(self, type = 'vlan', id = None, **kwargs):
        '''
        Build (but do not store) a PhysicalNetwork and its PhysicalNetworkMap.

        For type 'vlan' the ``vlanrange`` keyword is required and must be an ordered,
        non-overlapping sequence of (start, end) pairs below 4095, otherwise ValueError
        is raised. Every other type is stored as 'native'.
        '''
        if not id:
            id = str(uuid1())
        if type == 'vlan':
            if 'vlanrange' not in kwargs:
                raise ValueError(r'Must specify vlanrange with network type="vlan"')
            vlanrange = kwargs['vlanrange']
            # Check
            try:
                lastend = 0
                for start, end in vlanrange:
                    if start <= lastend:
                        raise ValueError('VLAN sequences overlapped or disordered')
                    lastend = end
                if lastend >= 4095:
                    raise ValueError('VLAN ID out of range')
            except Exception as exc:
                # Also covers malformed input, e.g. items which cannot be unpacked or compared
                raise ValueError('vlanrange format error: %s' % (str(exc),))
        else:
            type = 'native'
        new_network = PhysicalNetwork.create_instance(id)
        new_network.type = type
        for k, v in kwargs.items():
            setattr(new_network, k, v)
        new_networkmap = PhysicalNetworkMap.create_instance(id)
        new_networkmap.network = new_network.create_reference()
        return (new_network, new_networkmap)

    def createphysicalport(self, physicalnetwork, name, systemid = '%', bridge = '%', **kwargs):
        "Create a single physical port and return its dumped form"
        p = {'physicalnetwork': physicalnetwork, 'name': name, 'systemid': systemid, 'bridge': bridge}
        p.update(kwargs)
        for m in self.createphysicalports([p]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]

    def _createphysicalport(self, physicalnetwork, name, systemid = '%', bridge = '%', **kwargs):
        "Build (but do not store) a PhysicalPort which references *physicalnetwork*"
        new_port = PhysicalPort.create_instance(systemid, bridge, name)
        new_port.physicalnetwork = ReferenceObject(PhysicalNetwork.default_key(physicalnetwork))
        for k, v in kwargs.items():
            setattr(new_port, k, v)
        return new_port

    def createphysicalports(self, ports):
        "Create physical ports from a list of keyword dictionaries; return their dumped forms"
        new_ports = [self._createphysicalport(**p) for p in ports]
        physical_networks = list(set([p.physicalnetwork.getkey() for p in new_ports]))
        physical_maps = [PhysicalNetworkMap.default_key(*PhysicalNetwork._getIndices(k)[1])
                         for k in physical_networks]
        @updater
        def create_ports(portset, *objs):
            # Argument layout follows the 'keys' list below:
            # ports, then physical network maps, then physical networks
            old_ports = objs[:len(new_ports)]
            phymaps = list(objs[len(new_ports):len(new_ports) + len(physical_networks)])
            phynets = list(objs[len(new_ports) + len(physical_networks):])
            phydict = dict(zip(physical_networks, zip(phynets, phymaps)))
            return_ports = [None] * len(new_ports)
            for i in range(0, len(new_ports)):
                return_ports[i] = set_new(old_ports[i], new_ports[i])
            for p in return_ports:
                phynet, phymap = phydict[p.physicalnetwork.getkey()]
                if phynet is None:
                    _, (phyid,) = PhysicalNetwork._getIndices(p.physicalnetwork.getkey())
                    raise ValueError('Physical network %r does not exist' % (phyid,))
                phymap.ports.dataset().add(p.create_weakreference())
                portset.set.dataset().add(p.create_weakreference())
            return [portset] + return_ports + phymaps
        for m in callAPI(self.apiroutine, 'objectdb', 'transact',
                         {'keys': [PhysicalPortSet.default_key()] +
                                  [p.getkey() for p in new_ports] +
                                  physical_maps +
                                  physical_networks,
                          'updater': create_ports}):
            yield m
        for m in self._dumpkeys([p.getkey() for p in new_ports]):
            yield m

    def createlogicalport(self, logicalnetwork, id = None, **kwargs):
        "Create a single logical port and return its dumped form"
        p = {'logicalnetwork': logicalnetwork, 'id': id}
        p.update(kwargs)
        for m in self.createlogicalports([p]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]

    def _createlogicalport(self, logicalnetwork, id = None, **kwargs):
        "Build (but do not store) a LogicalPort which references *logicalnetwork*"
        if not id:
            id = str(uuid1())
        new_port = LogicalPort.create_instance(id)
        new_port.logicalnetwork = ReferenceObject(LogicalNetwork.default_key(logicalnetwork))
        for k, v in kwargs.items():
            setattr(new_port, k, v)
        return new_port

    def createlogicalports(self, ports):
        "Create logical ports from a list of keyword dictionaries; return their dumped forms"
        new_ports = [self._createlogicalport(**p) for p in ports]
        logical_networks = list(set([p.logicalnetwork.getkey() for p in new_ports]))
        logical_maps = [LogicalNetworkMap.default_key(*LogicalNetwork._getIndices(k)[1])
                        for k in logical_networks]
        @updater
        def create_ports(portset, *objs):
            # Argument layout follows the 'keys' list below:
            # ports, then logical network maps, then logical networks
            old_ports = objs[:len(new_ports)]
            logmaps = list(objs[len(new_ports):len(new_ports) + len(logical_networks)])
            lognets = list(objs[len(new_ports) + len(logical_networks):])
            logdict = dict(zip(logical_networks, zip(lognets, logmaps)))
            return_ports = [None] * len(new_ports)
            for i in range(0, len(new_ports)):
                return_ports[i] = set_new(old_ports[i], new_ports[i])
            for p in return_ports:
                lognet, logmap = logdict[p.logicalnetwork.getkey()]
                if lognet is None:
                    _, (logid,) = LogicalNetwork._getIndices(p.logicalnetwork.getkey())
                    raise ValueError('Logical network %r does not exist' % (logid,))
                logmap.ports.dataset().add(p.create_weakreference())
                portset.set.dataset().add(p.create_weakreference())
            return [portset] + return_ports + logmaps
        for m in callAPI(self.apiroutine, 'objectdb', 'transact',
                         {'keys': [LogicalPortSet.default_key()] +
                                  [p.getkey() for p in new_ports] +
                                  logical_maps +
                                  logical_networks,
                          'updater': create_ports}):
            yield m
        for m in self._dumpkeys([p.getkey() for p in new_ports]):
            yield m

    def getlogicalnetworks(self, id = None, physicalnetwork = None, **kwargs):
        '''
        Query logical networks and return a list of dumped objects.

        If *id* is specified, only that network is examined; else if *physicalnetwork*
        is specified, the networks in that physical network map are walked; otherwise
        all networks in the logical network set are walked. Extra keyword arguments
        are attribute values which the returned networks must match.
        '''
        def set_walker(key, objset, walk, save):
            # Save every member of the set whose attributes match all keyword filters
            if objset is None:
                return
            for o in objset.dataset():
                member_key = o.getkey()
                try:
                    net = walk(member_key)
                except KeyError:
                    pass
                else:
                    for k, v in kwargs.items():
                        if getattr(net, k, None) != v:
                            break
                    else:
                        save(member_key)

        def walker_func(set_func):
            def walker(key, obj, walk, save):
                if obj is None:
                    return
                set_walker(key, set_func(obj), walk, save)
            return walker

        if id is not None:
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            for m in callAPI(self.apiroutine, 'objectdb', 'get',
                             {'key': LogicalNetwork.default_key(id), 'requestid': reqid}):
                yield m
            result = self.apiroutine.retvalue
            with watch_context([LogicalNetwork.default_key(id)], [result], reqid, self.apiroutine):
                if result is None:
                    self.apiroutine.retvalue = []
                    return
                if physicalnetwork is not None and physicalnetwork != result.physicalnetwork.id:
                    self.apiroutine.retvalue = []
                    return
                for k, v in kwargs.items():
                    if getattr(result, k, None) != v:
                        self.apiroutine.retvalue = []
                        return
                self.apiroutine.retvalue = [dump(result)]
        elif physicalnetwork is not None:
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            pm_key = PhysicalNetworkMap.default_key(physicalnetwork)
            for m in callAPI(self.apiroutine, 'objectdb', 'walk',
                             {'keys': [pm_key],
                              'walkerdict': {pm_key: walker_func(lambda x: x.networks)},
                              'requestid': reqid}):
                yield m
            keys, result = self.apiroutine.retvalue
            with watch_context(keys, result, reqid, self.apiroutine):
                self.apiroutine.retvalue = [dump(r) for r in result]
        else:
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            ns_key = LogicalNetworkSet.default_key()
            for m in callAPI(self.apiroutine, 'objectdb', 'walk',
                             {'keys': [ns_key],
                              'walkerdict': {ns_key: walker_func(lambda x: x.set)},
                              'requestid': reqid}):
                yield m
            keys, result = self.apiroutine.retvalue
            with watch_context(keys, result, reqid, self.apiroutine):
                self.apiroutine.retvalue = [dump(r) for r in result]
class OVSDBManager(Module):
    '''
    Manage OVSDB (JSON-RPC) connections and the bridges discovered on them
    '''
    service = True
    # Bind to JsonRPCServer vHosts. If not None, should be a list of vHost names e.g. ``['']``
    _default_vhostbind = None
    # Only acquire information from bridges with this names
    _default_bridgenames = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapath-id) -> connection
        self.managed_conns = {}
        # (vhost, system-id) -> connection
        self.managed_systemids = {}
        # connection -> list of (vhost, datapath-id, name, bridge uuid)
        self.managed_bridges = {}
        self.managed_routines = []
        # (vhost, endpoint) -> list of connections
        self.endpoint_conns = {}
        self.createAPI(api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getbridges, self.apiroutine),
                       api(self.getbridge, self.apiroutine),
                       api(self.getbridgebyuuid, self.apiroutine),
                       api(self.waitbridge, self.apiroutine),
                       api(self.waitbridgebyuuid, self.apiroutine),
                       api(self.getsystemids, self.apiroutine),
                       api(self.getallsystemids, self.apiroutine),
                       api(self.getconnectionbysystemid, self.apiroutine),
                       api(self.waitconnectionbysystemid, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.getallbridges, self.apiroutine),
                       api(self.getbridgeinfo, self.apiroutine),
                       api(self.waitbridgeinfo, self.apiroutine)
                       )
        self._synchronized = False

    def _update_bridge(self, connection, protocol, bridge_uuid, vhost):
        '''
        Query name and datapath-id of the bridge with *bridge_uuid*; if the bridge is
        selected by ``bridgenames``, register it and send an OVSDBBridgeSetup UP event.
        '''
        try:
            # Wait (timeout argument 5000) until datapath_id is no longer an empty set, then select it
            method, params = ovsdb.transact('Open_vSwitch',
                                            ovsdb.wait('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                       ["datapath_id"], [{"datapath_id": ovsdb.oset()}], False, 5000),
                                            ovsdb.select('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                         ["datapath_id","name"]))
            for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                yield m
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            if r['rows']:
                r0 = r['rows'][0]
                name = r0['name']
                dpid = int(r0['datapath_id'], 16)
                if self.bridgenames is None or name in self.bridgenames:
                    self.managed_bridges[connection].append((vhost, dpid, name, bridge_uuid))
                    self.managed_conns[(vhost, dpid)] = connection
                    for m in self.apiroutine.waitForSend(OVSDBBridgeSetup(OVSDBBridgeSetup.UP,
                                                                          dpid,
                                                                          connection.ovsdb_systemid,
                                                                          name,
                                                                          connection,
                                                                          connection.connmark,
                                                                          vhost,
                                                                          bridge_uuid)):
                        yield m
        except JsonRPCProtocolException:
            pass

    def _get_bridges(self, connection, protocol):
        '''
        Per-connection routine: acquire the system-id, register the connection, set up a
        monitor on the Bridge table, process the initial bridges and then follow the
        monitor notifications until the connection goes down.
        '''
        try:
            try:
                vhost = protocol.vhost
                if not hasattr(connection, 'ovsdb_systemid'):
                    method, params = ovsdb.transact('Open_vSwitch', ovsdb.select('Open_vSwitch', [], ['external_ids']))
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    result = self.apiroutine.jsonrpc_result[0]
                    system_id = ovsdb.omap_getvalue(result['rows'][0]['external_ids'], 'system-id')
                    connection.ovsdb_systemid = system_id
                else:
                    system_id = connection.ovsdb_systemid
                if (vhost, system_id) in self.managed_systemids:
                    # A connection with the same system-id exists: the new connection replaces it
                    oc = self.managed_systemids[(vhost, system_id)]
                    ep = _get_endpoint(oc)
                    econns = self.endpoint_conns.get((vhost, ep))
                    if econns:
                        try:
                            econns.remove(oc)
                        except ValueError:
                            pass
                    del self.managed_systemids[(vhost, system_id)]
                self.managed_systemids[(vhost, system_id)] = connection
                self.managed_bridges[connection] = []
                ep = _get_endpoint(connection)
                self.endpoint_conns.setdefault((vhost, ep), []).append(connection)
                method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_manager_bridges_monitor',
                                               {'Bridge':ovsdb.monitor_request(['name', 'datapath_id'])})
                for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                    yield m
                if 'error' in self.apiroutine.jsonrpc_result:
                    # The monitor is already set, cancel it first
                    method, params = ovsdb.monitor_cancel('ovsdb_manager_bridges_monitor')
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine, False):
                        yield m
                    method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_manager_bridges_monitor',
                                                   {'Bridge':ovsdb.monitor_request(['name', 'datapath_id'])})
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    if 'error' in self.apiroutine.jsonrpc_result:
                        raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result))
            except Exception:
                # The setup event is sent even on failure; _manage_existing waits for it
                # NOTE(review): system_id is unbound here if the system-id query itself failed - confirm
                for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                    yield m
                raise
            else:
                # Process initial bridges
                init_subprocesses = [self._update_bridge(connection, protocol, buuid, vhost)
                                     for buuid in self.apiroutine.jsonrpc_result['Bridge'].keys()]
                def init_process():
                    try:
                        with closing(self.apiroutine.executeAll(init_subprocesses, retnames = ())) as g:
                            for m in g:
                                yield m
                    except Exception:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                        raise
                    else:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                self.apiroutine.subroutine(init_process())
            # Wait for notify
            notification = JsonRPCNotificationEvent.createMatcher('update', connection, connection.connmark,
                                                                  _ismatch = lambda x: x.params[0] == 'ovsdb_manager_bridges_monitor')
            conn_down = protocol.statematcher(connection)
            while True:
                yield (conn_down, notification)
                if self.apiroutine.matcher is conn_down:
                    break
                else:
                    for buuid, v in self.apiroutine.event.params[1]['Bridge'].items():
                        # If a bridge's name or datapath-id is changed, we remove this bridge and add it again
                        if 'old' in v:
                            # A bridge is deleted
                            bridges = self.managed_bridges[connection]
                            for i in range(0, len(bridges)):
                                if buuid == bridges[i][3]:
                                    self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                               bridges[i][1],
                                                                               system_id,
                                                                               bridges[i][2],
                                                                               connection,
                                                                               connection.connmark,
                                                                               vhost,
                                                                               bridges[i][3],
                                                                               new_datapath_id =
                                                                                   int(v['new']['datapath_id'], 16) if 'new' in v and 'datapath_id' in v['new']
                                                                                   else None))
                                    del self.managed_conns[(vhost, bridges[i][1])]
                                    del bridges[i]
                                    break
                        if 'new' in v:
                            # A bridge is added
                            self.apiroutine.subroutine(self._update_bridge(connection, protocol, buuid, vhost))
        except JsonRPCProtocolException:
            pass
        finally:
            # Remove the marker so that a later connection-up can start a new routine
            del connection._ovsdb_manager_get_bridges

    def _manage_existing(self):
        "Start managing the connections which already exist, then mark the module as synchronized"
        for m in callAPI(self.apiroutine, "jsonrpcserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        conns = self.apiroutine.retvalue
        for c in conns:
            if vb is None or c.protocol.vhost in vb:
                if not hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges = self.apiroutine.subroutine(self._get_bridges(c, c.protocol))
        matchers = [OVSDBConnectionSetup.createMatcher(None, c, c.connmark) for c in conns]
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Block until the 'synchronized' notification if the module is not synchronized yet"
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)

    def _manage_conns(self):
        '''
        Main routine: start a ``_get_bridges`` routine for every connection which comes up,
        and remove all records of a connection when it goes down.
        '''
        try:
            self.apiroutine.subroutine(self._manage_existing())
            vb = self.vhostbind
            if vb is not None:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP,
                                                                    _ismatch = lambda x: x.createby.vhost in vb)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                      _ismatch = lambda x: x.createby.vhost in vb)
            else:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    if not hasattr(self.apiroutine.event.connection, '_ovsdb_manager_get_bridges'):
                        self.apiroutine.event.connection._ovsdb_manager_get_bridges = self.apiroutine.subroutine(
                            self._get_bridges(self.apiroutine.event.connection, self.apiroutine.event.createby))
                else:
                    e = self.apiroutine.event
                    conn = e.connection
                    bridges = self.managed_bridges.get(conn)
                    if bridges is not None:
                        # Only remove the system-id record if it still belongs to this connection:
                        # _get_bridges may already have replaced it with a newer connection
                        sid_key = (e.createby.vhost, conn.ovsdb_systemid)
                        if self.managed_systemids.get(sid_key) is conn:
                            del self.managed_systemids[sid_key]
                        del self.managed_bridges[conn]
                        for vhost, dpid, name, buuid in bridges:
                            del self.managed_conns[(vhost, dpid)]
                            self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                       dpid,
                                                                       conn.ovsdb_systemid,
                                                                       name,
                                                                       conn,
                                                                       conn.connmark,
                                                                       e.createby.vhost,
                                                                       buuid))
                        # endpoint_conns is keyed by (vhost, endpoint), same as in _get_bridges
                        econns = self.endpoint_conns.get((e.createby.vhost, _get_endpoint(conn)))
                        if econns is not None:
                            try:
                                econns.remove(conn)
                            except ValueError:
                                pass
        finally:
            for c in self.managed_bridges.keys():
                if hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges.close()
                bridges = self.managed_bridges.get(c)
                if bridges is not None:
                    for vhost, dpid, name, buuid in bridges:
                        del self.managed_conns[(vhost, dpid)]
                        self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                   dpid,
                                                                   c.ovsdb_systemid,
                                                                   name,
                                                                   c,
                                                                   c.connmark,
                                                                   c.protocol.vhost,
                                                                   buuid))

    def getconnection(self, datapathid, vhost = ''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_conns.get((vhost, datapathid))

    def waitconnection(self, datapathid, timeout = 30, vhost = ''):
        "Wait for a datapath connection"
        for m in self.getconnection(datapathid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                                                     OVSDBBridgeSetup.createMatcher(
                                                         state = OVSDBBridgeSetup.UP,
                                                         datapathid = datapathid, vhost = vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getdatapathids(self, vhost = ''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys() if k[0] == vhost]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost = ''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = list(self.managed_bridges.keys())
        else:
            self.apiroutine.retvalue = [k for k in self.managed_bridges.keys() if k.protocol.vhost == vhost]

    def getbridges(self, connection):
        "Get all (dpid, name, _uuid) tuple on this connection"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            self.apiroutine.retvalue = [(dpid, name, buuid) for _, dpid, name, buuid in bridges]
        else:
            self.apiroutine.retvalue = None

    def getallbridges(self, vhost = None):
        "Get all (dpid, name, _uuid) tuple for all connections, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is not None:
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for c, bridges in self.managed_bridges.items()
                                        if c.protocol.vhost == vhost
                                        for _, dpid, name, buuid in bridges]
        else:
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for c, bridges in self.managed_bridges.items()
                                        for _, dpid, name, buuid in bridges]

    def getbridge(self, connection, name):
        "Get datapath ID on this connection with specified name"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, n, _ in bridges:
                if n == name:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None

    def waitbridge(self, connection, name, timeout = 30):
        "Wait for bridge with specified name appears and return the datapath-id"
        bnames = self.bridgenames
        if bnames is not None and name not in bnames:
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear: it is not in the selected bridge names')
        for m in self.getbridge(connection, name):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.UP,
                                                          None,
                                                          None,
                                                          name,
                                                          connection
                                                          )
            conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                  connection,
                                                                  connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException('Connection is down before bridge ' + repr(name) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getbridgebyuuid(self, connection, uuid):
        "Get datapath ID of bridge on this connection with specified _uuid"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, _, buuid in bridges:
                if buuid == uuid:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None

    def waitbridgebyuuid(self, connection, uuid, timeout = 30):
        "Wait for bridge with specified _uuid appears and return the datapath-id"
        for m in self.getbridgebyuuid(connection, uuid):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(state = OVSDBBridgeSetup.UP,
                                                          connection = connection,
                                                          bridgeuuid = uuid
                                                          )
            conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                  connection,
                                                                  connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge ' + repr(uuid) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException('Connection is down before bridge ' + repr(uuid) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getsystemids(self, vhost = ''):
        "Get All system-ids"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_systemids.keys() if k[0] == vhost]

    def getallsystemids(self):
        "Get all system-ids from any vhost. Return (vhost, system-id) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_systemids.keys())

    def getconnectionbysystemid(self, systemid, vhost = ''):
        "Get current connection with specified system-id, or None if it is not connected"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_systemids.get((vhost, systemid))

    def waitconnectionbysystemid(self, systemid, timeout = 30, vhost = ''):
        "Wait for a connection with specified system-id"
        for m in self.getconnectionbysystemid(systemid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                                                     OVSDBConnectionSetup.createMatcher(
                                                         systemid, None, None, vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getconnectionsbyendpoint(self, endpoint, vhost = ''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost = '', timeout = 30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        # Resolve the name
        if not name:
            endpoint = ''
            # Use the existing API method (the original called a non-existent name)
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP,
                       socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                # Resolve is only allowed through asynchronous resolver
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                # Do not return a stale value when the resolver returns no address
                self.apiroutine.retvalue = None
                for r in resp:
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # Ignore port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    # The first resolved address with known connections wins
                    if self.apiroutine.retvalue is not None:
                        break

    def getendpoints(self, vhost = ''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def getbridgeinfo(self, datapathid, vhost = ''):
        "Get (bridgename, systemid, bridge_uuid) tuple from bridge datapathid"
        for m in self.getconnection(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is not None:
            c = self.apiroutine.retvalue
            bridges = self.managed_bridges.get(c)
            if bridges is not None:
                for _, dpid, n, buuid in bridges:
                    if dpid == datapathid:
                        self.apiroutine.retvalue = (n, c.ovsdb_systemid, buuid)
                        return
                self.apiroutine.retvalue = None
            else:
                self.apiroutine.retvalue = None

    def waitbridgeinfo(self, datapathid, timeout = 30, vhost = ''):
        "Wait for bridge with datapathid, and return (bridgename, systemid, bridge_uuid) tuple"
        for m in self.getbridgeinfo(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                                                     OVSDBBridgeSetup.createMatcher(
                                                         OVSDBBridgeSetup.UP, datapathid,
                                                         None, None, None, None,
                                                         vhost)):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge 0x%016x does not appear before timeout' % (datapathid,))
            e = self.apiroutine.event
            self.apiroutine.retvalue = (e.name, e.systemid, e.bridgeuuid)
class OVSDBPortManager(Module):
    '''
    Manage Ports from OVSDB Protocol

    Tracks the interfaces of every bridge reported by the OVSDB monitor and
    exposes lookup / wait APIs by port number, name and ``iface-id``.
    '''
    service = True

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # (vhost, datapath_id) -> list of (port_uuid, interface_row) pairs
        self.managed_ports = {}
        # (vhost, iface-id) -> list of (datapath_id, interface_row) pairs
        self.managed_ids = {}
        # Running _get_ports routines, one per monitored connection
        self.monitor_routines = set()
        # port uuid -> uuid of the bridge currently containing the port
        self.ports_uuids = {}
        # Waiter counters; an OVSDBPortUpNotification is only sent when
        # a matching key is present in one of these dictionaries
        self.wait_portnos = {}
        self.wait_names = {}
        self.wait_ids = {}
        # bridge uuid -> datapath id
        self.bridge_datapathid = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyid, self.apiroutine),
                       api(self.waitportbyid, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.resync, self.apiroutine)
                       )
        self._synchronized = False

    def _get_interface_info(self, connection, protocol, buuid, interface_uuid, port_uuid):
        """
        Query a single interface row and register it as a managed port.

        On completion ``self.apiroutine.retvalue`` is ``[row]`` when the
        interface was registered, or ``[]`` when it was skipped (row missing,
        negative ofport, unknown bridge, or a JSON-RPC protocol error).
        """
        try:
            condition = [["_uuid", "==", ovsdb.uuid(interface_uuid)]]
            # Wait (up to 5000 ms each) until ofport and ifindex are set, then read the row
            method, params = ovsdb.transact(
                'Open_vSwitch',
                ovsdb.wait('Interface', condition,
                           ["ofport"], [{"ofport": ovsdb.oset()}], False, 5000),
                ovsdb.wait('Interface', condition,
                           ["ifindex"], [{"ifindex": ovsdb.oset()}], False, 5000),
                ovsdb.select('Interface', condition,
                             ["_uuid", "name", "ifindex", "ofport", "type", "external_ids"]))
            for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                yield m
            results = self.apiroutine.jsonrpc_result
            # Check the three operation results in order; r ends as the select result
            for index in range(3):
                r = results[index]
                if 'error' in r:
                    raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error']))
            if not r['rows']:
                self.apiroutine.retvalue = []
                return
            r0 = r['rows'][0]
            if r0['ofport'] < 0:
                # Ignore this port because it is in an error state
                self.apiroutine.retvalue = []
                return
            r0['_uuid'] = r0['_uuid'][1]
            r0['ifindex'] = ovsdb.getoptional(r0['ifindex'])
            r0['external_ids'] = ovsdb.getdict(r0['external_ids'])
            if buuid not in self.bridge_datapathid:
                self.apiroutine.retvalue = []
                return
            datapath_id = self.bridge_datapathid[buuid]
            vhost = protocol.vhost
            if 'iface-id' in r0['external_ids']:
                eid = r0['external_ids']['iface-id']
                r0['id'] = eid
                self.managed_ids.setdefault((vhost, eid), []).append((datapath_id, r0))
            else:
                r0['id'] = None
            self.managed_ports.setdefault((vhost, datapath_id), []).append((port_uuid, r0))
            # Only notify when some routine is waiting for this port
            notify = False
            for waiters, key in ((self.wait_portnos, (vhost, datapath_id, r0['ofport'])),
                                 (self.wait_names, (vhost, datapath_id, r0['name'])),
                                 (self.wait_ids, (vhost, r0['id']))):
                if key in waiters:
                    notify = True
                    del waiters[key]
            if notify:
                for m in self.apiroutine.waitForSend(
                        OVSDBPortUpNotification(connection, r0['name'], r0['ofport'], r0['id'],
                                                protocol.vhost, datapath_id, port = r0)):
                    yield m
            self.apiroutine.retvalue = [r0]
        except JsonRPCProtocolException:
            self.apiroutine.retvalue = []

    def _remove_interface_id(self, connection, protocol, datapath_id, port):
        """
        Remove *port* from the ``iface-id`` index.

        Tolerates an id that is not (or no longer) indexed, and drops the
        index entry once its list becomes empty.
        """
        key = (protocol.vhost, port['id'])
        eid_list = self.managed_ids.get(key)
        if eid_list is None:
            # Nothing registered for this id; previously this raised TypeError
            return
        for i, (_, p) in enumerate(eid_list):
            if p['_uuid'] == port['_uuid']:
                del eid_list[i]
                break
        if not eid_list:
            del self.managed_ids[key]

    def _remove_interface(self, connection, protocol, datapath_id, interface_uuid, port_uuid):
        """
        Remove the interface with *interface_uuid* from the datapath.

        Returns the removed interface row, or ``None`` if it was not managed.
        """
        key = (protocol.vhost, datapath_id)
        ports = self.managed_ports.get(key)
        r = None
        if ports is not None:
            for i, (_, p) in enumerate(ports):
                if p['_uuid'] == interface_uuid:
                    r = p
                    if r['id']:
                        self._remove_interface_id(connection, protocol, datapath_id, r)
                    del ports[i]
                    break
            if not ports:
                del self.managed_ports[key]
        return r

    def _remove_all_interface(self, connection, protocol, datapath_id, port_uuid, buuid):
        """
        Remove every interface belonging to port *port_uuid* from the datapath.

        Returns the list of removed interface rows (possibly empty).
        """
        # Forget the port -> bridge mapping whether or not interfaces are tracked;
        # the buuid check keeps a mapping already re-assigned to another bridge.
        if port_uuid in self.ports_uuids and self.ports_uuids[port_uuid] == buuid:
            del self.ports_uuids[port_uuid]
        key = (protocol.vhost, datapath_id)
        ports = self.managed_ports.get(key)
        if ports is None:
            return []
        removed_ports = [r for puuid, r in ports if puuid == port_uuid]
        # Modify in place: other holders of this list must see the change
        ports[:] = [(puuid, r) for puuid, r in ports if puuid != port_uuid]
        for r in removed_ports:
            if r['id']:
                self._remove_interface_id(connection, protocol, datapath_id, r)
        if not ports:
            del self.managed_ports[key]
        return removed_ports

    def _update_interfaces(self, connection, protocol, updateinfo, update = True):
        """
        There are several kinds of updates, they may appear together:

        1. New bridge created (or from initial updateinfo). We should add all the interfaces to the list.

        2. Bridge removed. Remove all the ports.

        3. name and datapath_id may be changed. We will consider this as a new bridge created, and an old
           bridge removed.

        4. Bridge ports modification, i.e. add/remove ports.
           a) Normally a port record is created/deleted together. A port record cannot exist without a
              bridge containing it.

           b) It is also possible that a port is removed from one bridge and added to another bridge, in
              this case the ports do not appear in the updateinfo

        5. Port interfaces modification, i.e. add/remove interfaces. The bridge record may not appear in this
           situation.

        We must consider these situations carefully and process them in correct order.

        When *update* is true the work is started as subroutines and 'update'
        notifications are sent; otherwise (initial state) the work is awaited
        and OVSDBConnectionPortsSynchronized is sent when it finishes.
        """
        port_update = updateinfo.get('Port', {})
        bridge_update = updateinfo.get('Bridge', {})
        working_routines = []

        def uuid_set(value):
            # Each element of the OVSDB list is a (type, uuid) pair; keep the uuid
            return set(u for _, u in ovsdb.getlist(value))

        def process_bridge(buuid, uo):
            try:
                nv = uo['new']
                if 'datapath_id' in nv:
                    datapath_id = int(nv['datapath_id'], 16)
                    self.bridge_datapathid[buuid] = datapath_id
                elif buuid in self.bridge_datapathid:
                    datapath_id = self.bridge_datapathid[buuid]
                else:
                    # This should not happen, but just in case...
                    for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitbridge',
                                     {'connection': connection,
                                      'name': nv['name'],
                                      'timeout': 5}):
                        yield m
                    datapath_id = self.apiroutine.retvalue
                if 'ports' in nv:
                    nset = uuid_set(nv['ports'])
                else:
                    nset = set()
                if 'old' in uo:
                    ov = uo['old']
                    if 'ports' in ov:
                        oset = uuid_set(ov['ports'])
                    else:
                        # new ports are not really added; it is only sent because datapath_id is modified
                        nset = set()
                        oset = set()
                    if 'datapath_id' in ov:
                        old_datapathid = int(ov['datapath_id'], 16)
                    else:
                        old_datapathid = datapath_id
                else:
                    oset = set()
                    old_datapathid = datapath_id
                # For every deleted port, remove the interfaces with this port _uuid
                remove = []
                add_routine = []
                for puuid in oset - nset:
                    remove += self._remove_all_interface(connection, protocol, old_datapathid, puuid, buuid)
                # For every port not changed, check if the interfaces are modified;
                for puuid in oset.intersection(nset):
                    if puuid in port_update:
                        # The port is modified, there should be an 'old' set and 'new' set
                        pu = port_update[puuid]
                        if 'old' in pu:
                            poset = uuid_set(pu['old']['interfaces'])
                        else:
                            poset = set()
                        if 'new' in pu:
                            pnset = uuid_set(pu['new']['interfaces'])
                        else:
                            pnset = set()
                        # Remove old interfaces
                        remove += [r for r in
                                   (self._remove_interface(connection, protocol, datapath_id, iuuid, puuid)
                                    for iuuid in (poset - pnset))
                                   if r is not None]
                        # Prepare to add new interfaces
                        add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                        for iuuid in (pnset - poset)]

                # For every port added, add the interfaces
                def add_port_interfaces(puuid):
                    # If the uuid does not appear in update info, we have no choice but to query interfaces with select
                    # we cannot use data from other bridges; the port may be moved from a bridge which is not tracked
                    try:
                        method, params = ovsdb.transact(
                            'Open_vSwitch',
                            ovsdb.select('Port',
                                         [["_uuid", "==", ovsdb.uuid(puuid)]],
                                         ["interfaces"]))
                        for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                            yield m
                        r = self.apiroutine.jsonrpc_result[0]
                        if 'error' in r:
                            raise JsonRPCErrorResultException('Error when query interfaces from port ' + repr(puuid) + ': ' + r['error'])
                        if r['rows']:
                            interfaces = ovsdb.getlist(r['rows'][0]['interfaces'])
                            with closing(self.apiroutine.executeAll(
                                    [self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                     for _, iuuid in interfaces])) as g:
                                for m in g:
                                    yield m
                            # Each result's first item is the list produced by
                            # _get_interface_info; flatten them into one list of rows
                            self.apiroutine.retvalue = list(itertools.chain.from_iterable(
                                res[0] for res in self.apiroutine.retvalue))
                        else:
                            self.apiroutine.retvalue = []
                    except JsonRPCProtocolException:
                        self.apiroutine.retvalue = []
                    except ConnectionResetException:
                        self.apiroutine.retvalue = []

                for puuid in nset - oset:
                    self.ports_uuids[puuid] = buuid
                    if puuid in port_update and 'new' in port_update[puuid] \
                            and 'old' not in port_update[puuid]:
                        # Add all the interfaces in 'new'
                        interfaces = ovsdb.getlist(port_update[puuid]['new']['interfaces'])
                        add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                        for _, iuuid in interfaces]
                    else:
                        add_routine.append(add_port_interfaces(puuid))
                # Execute the add_routine. 'add' is pre-initialised so the
                # notification in 'finally' is always well-defined, even when
                # this routine is closed or fails part-way.
                add = []
                try:
                    with closing(self.apiroutine.executeAll(add_routine)) as g:
                        for m in g:
                            yield m
                    add = list(itertools.chain.from_iterable(
                        res[0] for res in self.apiroutine.retvalue))
                finally:
                    if update:
                        self.scheduler.emergesend(
                            ModuleNotification(self.getServiceName(), 'update',
                                               datapathid = datapath_id,
                                               connection = connection,
                                               vhost = protocol.vhost,
                                               add = add, remove = remove,
                                               reason = 'bridgemodify'
                                                   if 'old' in uo
                                                   else 'bridgeup'
                                               ))
            except JsonRPCProtocolException:
                pass
            except ConnectionResetException:
                pass
            except OVSDBBridgeNotAppearException:
                pass

        ignore_ports = set()
        for buuid, uo in bridge_update.items():
            # Bridge removals are ignored because we process OVSDBBridgeSetup event instead
            if 'old' in uo:
                if 'ports' in uo['old']:
                    ignore_ports.update(uuid_set(uo['old']['ports']))
                if 'new' not in uo:
                    if buuid in self.bridge_datapathid:
                        del self.bridge_datapathid[buuid]
            if 'new' in uo:
                # If bridge contains this port is updated, we process the port update totally in bridge,
                # so we ignore it later
                if 'ports' in uo['new']:
                    ignore_ports.update(uuid_set(uo['new']['ports']))
                working_routines.append(process_bridge(buuid, uo))

        def process_port(buuid, port_uuid, interfaces, remove_ids):
            if buuid not in self.bridge_datapathid:
                return
            datapath_id = self.bridge_datapathid[buuid]
            ports = self.managed_ports.get((protocol.vhost, datapath_id))
            remove = []
            if ports is not None:
                remove = [p for _, p in ports if p['_uuid'] in remove_ids]
                ports[:] = [(pu, p) for pu, p in ports if p['_uuid'] not in remove_ids]
                # Keep the iface-id index consistent with the other removal paths
                for p in remove:
                    if p['id']:
                        self._remove_interface_id(connection, protocol, datapath_id, p)
            if interfaces:
                try:
                    with closing(self.apiroutine.executeAll(
                            [self._get_interface_info(connection, protocol, buuid, iuuid, port_uuid)
                             for iuuid in interfaces])) as g:
                        for m in g:
                            yield m
                    add = list(itertools.chain.from_iterable(
                        res[0] for res in self.apiroutine.retvalue if res[0]))
                except Exception:
                    self._logger.warning("Cannot get new port information", exc_info = True)
                    add = []
            else:
                add = []
            if update:
                for m in self.apiroutine.waitForSend(
                        ModuleNotification(self.getServiceName(), 'update',
                                           datapathid = datapath_id,
                                           connection = connection,
                                           vhost = protocol.vhost,
                                           add = add, remove = remove,
                                           reason = 'bridgemodify'
                                           )):
                    yield m

        for puuid, po in port_update.items():
            if puuid in ignore_ports:
                continue
            bridge_id = self.ports_uuids.get(puuid)
            if bridge_id is None:
                continue
            # .get(): the bridge may already have been removed from the map
            if self.bridge_datapathid.get(bridge_id) is None:
                continue
            # This port is modified
            if 'new' in po:
                nset = uuid_set(po['new']['interfaces'])
            else:
                nset = set()
            if 'old' in po:
                oset = uuid_set(po['old']['interfaces'])
            else:
                oset = set()
            working_routines.append(process_port(bridge_id, puuid, nset - oset, oset - nset))
        if update:
            for r in working_routines:
                self.apiroutine.subroutine(r)
        else:
            try:
                with closing(self.apiroutine.executeAll(working_routines, None, ())) as g:
                    for m in g:
                        yield m
            finally:
                self.scheduler.emergesend(OVSDBConnectionPortsSynchronized(connection))

    def _get_ports(self, connection, protocol):
        """
        Monitor Bridge/Port tables on *connection* and keep the port tables
        updated until the connection state matcher fires.
        """
        try:
            try:
                method, params = ovsdb.monitor(
                    'Open_vSwitch',
                    'ovsdb_port_manager_interfaces_monitor',
                    {
                        'Bridge': [ovsdb.monitor_request(["name", "datapath_id", "ports"])],
                        'Port': [ovsdb.monitor_request(["interfaces"])]
                    })
                for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                    yield m
                if 'error' in self.apiroutine.jsonrpc_result:
                    # The monitor is already set, cancel it first
                    method2, params2 = ovsdb.monitor_cancel('ovsdb_port_manager_interfaces_monitor')
                    for m in protocol.querywithreply(method2, params2, connection, self.apiroutine, False):
                        yield m
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    if 'error' in self.apiroutine.jsonrpc_result:
                        raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result))
                r = self.apiroutine.jsonrpc_result
            except:
                # Unblock _get_existing_ports, which waits for this event.
                # NOTE(review): this handler yields, so closing the routine while it
                # is in the block above would yield during GeneratorExit - confirm
                # whether that can happen in practice.
                for m in self.apiroutine.waitForSend(OVSDBConnectionPortsSynchronized(connection)):
                    yield m
                raise
            # This is the initial state, it should contains all the ids of ports and interfaces
            self.apiroutine.subroutine(self._update_interfaces(connection, protocol, r, False))
            update_matcher = JsonRPCNotificationEvent.createMatcher(
                'update', connection, connection.connmark,
                _ismatch = lambda x: x.params[0] == 'ovsdb_port_manager_interfaces_monitor')
            conn_state = protocol.statematcher(connection)
            while True:
                yield (update_matcher, conn_state)
                if self.apiroutine.matcher is conn_state:
                    break
                self.apiroutine.subroutine(
                    self._update_interfaces(connection, protocol, self.apiroutine.event.params[1], True))
        except JsonRPCProtocolException:
            pass
        finally:
            if self.apiroutine.currentroutine in self.monitor_routines:
                self.monitor_routines.remove(self.apiroutine.currentroutine)

    def _get_existing_ports(self):
        """
        Start monitoring every existing connection, wait until all of them
        report OVSDBConnectionPortsSynchronized, then announce 'synchronized'.
        """
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getallconnections', {'vhost':None}):
            yield m
        matchers = []
        for c in self.apiroutine.retvalue:
            self.monitor_routines.add(self.apiroutine.subroutine(self._get_ports(c, c.protocol)))
            matchers.append(OVSDBConnectionPortsSynchronized.createMatcher(c))
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        "Wait for the 'synchronized' notification unless already synchronized"
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)

    def _manage_ports(self):
        """
        Main routine: monitor new connections and clean up (or re-add) ports
        when a bridge goes down or changes its datapath id.
        """
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            connsetup = OVSDBConnectionSetup.createMatcher()
            bridgedown = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.DOWN)

            def re_add_interfaces(e, buuid, ports_original):
                # The event and port list are passed as arguments: this runs as a
                # subroutine, and the loop below rebinds its local names per event
                with closing(self.apiroutine.executeAll(
                        [self._get_interface_info(e.connection, e.connection.protocol, buuid,
                                                  r['_uuid'], puuid)
                         for puuid, r in ports_original])) as g:
                    for m in g:
                        yield m
                add = list(itertools.chain.from_iterable(
                    res[0] for res in self.apiroutine.retvalue))
                for m in self.apiroutine.waitForSend(
                        ModuleNotification(self.getServiceName(), 'update',
                                           datapathid = e.datapathid,
                                           connection = e.connection,
                                           vhost = e.vhost,
                                           add = add, remove = [],
                                           reason = 'bridgeup'
                                           )):
                    yield m

            while True:
                yield (connsetup, bridgedown)
                e = self.apiroutine.event
                if self.apiroutine.matcher is connsetup:
                    self.monitor_routines.add(
                        self.apiroutine.subroutine(self._get_ports(e.connection, e.connection.protocol)))
                    continue
                # Remove ports of the bridge
                ports_original = self.managed_ports.get((e.vhost, e.datapathid))
                if ports_original is None:
                    continue
                ports = [p for _, p in ports_original]
                for p in ports:
                    if p['id']:
                        self._remove_interface_id(e.connection, e.connection.protocol, e.datapathid, p)
                newdpid = getattr(e, 'new_datapath_id', None)
                buuid = e.bridgeuuid
                if newdpid is not None:
                    # This bridge changes its datapath id
                    if buuid in self.bridge_datapathid and self.bridge_datapathid[buuid] == e.datapathid:
                        self.bridge_datapathid[buuid] = newdpid
                    self.apiroutine.subroutine(re_add_interfaces(e, buuid, ports_original))
                else:
                    # The ports are removed
                    for puuid, _ in ports_original:
                        if puuid in self.ports_uuids and self.ports_uuids[puuid] == buuid:
                            del self.ports_uuids[puuid]
                del self.managed_ports[(e.vhost, e.datapathid)]
                self.scheduler.emergesend(
                    ModuleNotification(self.getServiceName(), 'update',
                                       datapathid = e.datapathid,
                                       connection = e.connection,
                                       vhost = e.vhost,
                                       add = [], remove = ports,
                                       reason = 'bridgedown'
                                       ))
        finally:
            for r in list(self.monitor_routines):
                r.close()
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def _release_waiter(self, waiters, key):
        "Decrease the waiter count for *key*, removing the entry when it reaches zero"
        count = waiters.get(key)
        if count is not None:
            if count <= 1:
                del waiters[key]
            else:
                waiters[key] = count - 1

    def getports(self, datapathid, vhost = ''):
        "Return all ports of a specified datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [p for _, p in self.managed_ports.get((vhost, datapathid), [])]

    def getallports(self, vhost = None):
        "Return all (datapathid, port, vhost) tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [(dpid, p, vh)
                                    for (vh, dpid), v in self.managed_ports.items()
                                    if vhost is None or vh == vhost
                                    for _, p in v]

    def getportbyno(self, datapathid, portno, vhost = ''):
        "Return port with specified portno"
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost)

    def _getportbyno(self, datapathid, portno, vhost = ''):
        "Return the port with ofport == portno on the datapath, or None"
        for _, p in self.managed_ports.get((vhost, datapathid), ()):
            if p['ofport'] == portno:
                return p
        return None

    def waitportbyno(self, datapathid, portno, timeout = 30, vhost = ''):
        "Wait for port with specified portno"
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m
        key = (vhost, datapathid, portno)
        def waitinner():
            p = self._getportbyno(datapathid, portno, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
                return
            try:
                self.wait_portnos[key] = self.wait_portnos.get(key, 0) + 1
                yield (OVSDBPortUpNotification.createMatcher(None, None, portno, None, vhost, datapathid),)
            except:
                # Bare except on purpose: also runs when this generator is closed
                self._release_waiter(self.wait_portnos, key)
                raise
            else:
                self.apiroutine.retvalue = self.apiroutine.event.port
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(portno) + ' does not appear before timeout')

    def getportbyname(self, datapathid, name, vhost = ''):
        "Return port with specified name"
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)

    def _getportbyname(self, datapathid, name, vhost = ''):
        "Return the port named *name* on the datapath, or None"
        for _, p in self.managed_ports.get((vhost, datapathid), ()):
            if p['name'] == name:
                return p
        return None

    def waitportbyname(self, datapathid, name, timeout = 30, vhost = ''):
        "Wait for port with specified name"
        # Same normalization as getportbyname
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        for m in self._wait_for_sync():
            yield m
        key = (vhost, datapathid, name)
        def waitinner():
            p = self._getportbyname(datapathid, name, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
                return
            try:
                # Count waiters in wait_names (the dictionary checked for names
                # when a port appears), not wait_portnos
                self.wait_names[key] = self.wait_names.get(key, 0) + 1
                yield (OVSDBPortUpNotification.createMatcher(None, name, None, None, vhost, datapathid),)
            except:
                # Bare except on purpose: also runs when this generator is closed
                self._release_waiter(self.wait_names, key)
                raise
            else:
                self.apiroutine.retvalue = self.apiroutine.event.port
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(name) + ' does not appear before timeout')

    def getportbyid(self, id, vhost = ''):
        "Return port with the specified id. The return value is a pair: (datapath_id, port)"
        for m in self._wait_for_sync():
            yield m
        # Store the result in retvalue; do not replace the routine container itself
        self.apiroutine.retvalue = self._getportbyid(id, vhost)

    def _getportbyid(self, id, vhost = ''):
        "Return the first (datapath_id, port) pair registered for *id*, or None"
        ports = self.managed_ids.get((vhost, id))
        if ports:
            return ports[0]
        return None

    def waitportbyid(self, id, timeout = 30, vhost = ''):
        "Wait for port with the specified id. The return value is a pair (datapath_id, port)"
        for m in self._wait_for_sync():
            yield m
        key = (vhost, id)
        def waitinner():
            p = self._getportbyid(id, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
                return
            try:
                self.wait_ids[key] = self.wait_ids.get(key, 0) + 1
                yield (OVSDBPortUpNotification.createMatcher(None, None, None, id, vhost),)
            except:
                # Bare except on purpose: also runs when this generator is closed
                self._release_waiter(self.wait_ids, key)
                raise
            else:
                self.apiroutine.retvalue = (self.apiroutine.event.datapathid,
                                            self.apiroutine.event.port)
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(id) + ' does not appear before timeout')

    def resync(self, datapathid, vhost = ''):
        '''
        Resync with current ports
        '''
        # Sometimes when the OVSDB connection is very busy, monitor message may be dropped.
        # We must deal with this and recover from it
        if (vhost, datapathid) not in self.managed_ports:
            self.apiroutine.retvalue = None
            return
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getconnection',
                         {'datapathid': datapathid, 'vhost': vhost}):
            yield m
        c = self.apiroutine.retvalue
        if c is not None:
            # For now, we restart the connection...
            for m in c.reconnect(False):
                yield m
            for m in self.apiroutine.waitWithTimeout(0.1):
                yield m
            for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitconnection',
                             {'datapathid': datapathid, 'vhost': vhost}):
                yield m
        self.apiroutine.retvalue = None