# Exemplo n.º 1 (scraped example marker)
class TestModule(Module):
    # ZooKeeper server endpoints shared by the client pool
    _default_serverlist = ['tcp://localhost:3181/','tcp://localhost:3182/','tcp://localhost:3183/']
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self.main
        self.routines.append(self.apiroutine)
    def main(self):
        """Benchmark ZooKeeper throughput: run 100 concurrent loops, each
        performing 100 create/delete multi-transactions, over a pool of 10
        connections, then print the elapsed time."""
        clients = [ZooKeeperClient(self.apiroutine, self.serverlist) for _ in range(0,10)]
        for c in clients:
            c.start()
        def test_loop(number):
            # Each loop works under its own root node so loops do not interfere
            maindir = ('vlcptest_' + str(number)).encode('utf-8')
            client = clients[number % len(clients)]
            for _ in range(0, 100):
                for m in client.requests([zk.multi(
                                                zk.multi_create(maindir, b'test'),
                                                # Fix: node data must be bytes ('test2' was str,
                                                # inconsistent with b'test' above)
                                                zk.multi_create(maindir + b'/subtest', b'test2')
                                            ),
                                          zk.getchildren2(maindir, True)], self.apiroutine):
                    yield m
                for m in client.requests([zk.multi(
                                                zk.multi_delete(maindir + b'/subtest'),
                                                zk.multi_delete(maindir)),
                                          zk.getchildren2(maindir, True)], self.apiroutine):
                    yield m
        from time import time
        starttime = time()
        for m in self.apiroutine.executeAll([test_loop(i) for i in range(0, 100)]):
            yield m
        print('10000 loops in %r seconds, with %d connections' % (time() - starttime, len(clients)))
        for c in clients:
            for m in c.shutdown():
                yield m
# Exemplo n.º 2 (scraped example marker)
class TestModule(Module):
    # ZooKeeper server endpoints shared by the client pool
    _default_serverlist = ['tcp://localhost:3181/','tcp://localhost:3182/','tcp://localhost:3183/']
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self.main
        self.routines.append(self.apiroutine)
    def main(self):
        """Benchmark ZooKeeper throughput: run 100 concurrent loops, each
        performing 100 create/delete multi-transactions, over a pool of 10
        connections, then print the elapsed time."""
        clients = [ZooKeeperClient(self.apiroutine, self.serverlist) for _ in range(0,10)]
        for c in clients:
            c.start()
        def test_loop(number):
            # Each loop works under its own root node so loops do not interfere
            maindir = ('vlcptest_' + str(number)).encode('utf-8')
            client = clients[number % len(clients)]
            for _ in range(0, 100):
                for m in client.requests([zk.multi(
                                                zk.multi_create(maindir, b'test'),
                                                # Fix: node data must be bytes ('test2' was str,
                                                # inconsistent with b'test' above)
                                                zk.multi_create(maindir + b'/subtest', b'test2')
                                            ),
                                          zk.getchildren2(maindir, True)], self.apiroutine):
                    yield m
                for m in client.requests([zk.multi(
                                                zk.multi_delete(maindir + b'/subtest'),
                                                zk.multi_delete(maindir)),
                                          zk.getchildren2(maindir, True)], self.apiroutine):
                    yield m
        from time import time
        starttime = time()
        for m in self.apiroutine.executeAll([test_loop(i) for i in range(0, 100)]):
            yield m
        print('10000 loops in %r seconds, with %d connections' % (time() - starttime, len(clients)))
        for c in clients:
            for m in c.shutdown():
                yield m
# Exemplo n.º 3 (scraped example marker)
class OVSDBManager(Module):
    '''
    Manage Openflow Connections
    '''
    service = True
    # Bind to JsonRPCServer vHosts. If not None, should be a list of vHost names e.g. ``['']``
    _default_vhostbind = None
    # Only acquire information from bridges with this names
    _default_bridgenames = None

    def __init__(self, server):
        """Create the OVSDB manager module and register its public API."""
        Module.__init__(self, server)
        # Main routine container; its main loop is _manage_conns
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapath-id) -> JSON-RPC connection
        self.managed_conns = {}
        # (vhost, system-id) -> JSON-RPC connection
        self.managed_systemids = {}
        # connection -> list of (vhost, datapath-id, bridge name, bridge _uuid)
        self.managed_bridges = {}
        # NOTE(review): not referenced anywhere in this chunk — possibly unused
        self.managed_routines = []
        # (vhost, endpoint address) -> list of connections from that endpoint
        self.endpoint_conns = {}
        self.createAPI(api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getbridges, self.apiroutine),
                       api(self.getbridge, self.apiroutine),
                       api(self.getbridgebyuuid, self.apiroutine),
                       api(self.waitbridge, self.apiroutine),
                       api(self.waitbridgebyuuid, self.apiroutine),
                       api(self.getsystemids, self.apiroutine),
                       api(self.getallsystemids, self.apiroutine),
                       api(self.getconnectionbysystemid, self.apiroutine),
                       api(self.waitconnectionbysystemid, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.getallbridges, self.apiroutine),
                       api(self.getbridgeinfo, self.apiroutine),
                       api(self.waitbridgeinfo, self.apiroutine))
        # Set to True by _manage_existing once startup synchronization is done
        self._synchronized = False

    def _update_bridge(self, connection, protocol, bridge_uuid, vhost):
        """Query name and datapath-id of one bridge and register it.

        Runs a single OVSDB transaction: a wait on the bridge's datapath_id
        (up to 5000 ms) followed by a select of name and datapath_id. On
        success the bridge is recorded in managed_bridges and managed_conns
        and an OVSDBBridgeSetup UP event is broadcast. JSON-RPC protocol
        failures are silently ignored (connection teardown handles cleanup).
        """
        try:
            # Single transaction: wait for datapath_id to be assigned
            # (presumably "until it differs from the empty set", per the
            # False flag — TODO confirm ovsdb.wait semantics), then select
            # the bridge's datapath_id and name.
            method, params = ovsdb.transact(
                'Open_vSwitch',
                ovsdb.wait(
                    'Bridge',
                    [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                    ["datapath_id"], [{
                        "datapath_id": ovsdb.oset()
                    }], False, 5000),
                ovsdb.select(
                    'Bridge',
                    [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                    ["datapath_id", "name"]))
            for m in protocol.querywithreply(method, params, connection,
                                             self.apiroutine):
                yield m
            # Check both sub-results (wait, then select) for errors
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring datapath-id: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring datapath-id: ' + repr(r['error']))
            if r['rows']:
                r0 = r['rows'][0]
                name = r0['name']
                # datapath_id is reported as a hex string
                dpid = int(r0['datapath_id'], 16)
                # Honor the optional bridge-name filter
                if self.bridgenames is None or name in self.bridgenames:
                    self.managed_bridges[connection].append(
                        (vhost, dpid, name, bridge_uuid))
                    self.managed_conns[(vhost, dpid)] = connection
                    for m in self.apiroutine.waitForSend(
                            OVSDBBridgeSetup(OVSDBBridgeSetup.UP, dpid,
                                             connection.ovsdb_systemid, name,
                                             connection, connection.connmark,
                                             vhost, bridge_uuid)):
                        yield m
        except JsonRPCProtocolException:
            pass

    def _get_bridges(self, connection, protocol):
        """Per-connection subroutine: register the connection, monitor its
        bridges, and keep bridge state up to date.

        Flow: acquire the switch system-id, replace any stale connection with
        the same (vhost, system-id), install an OVSDB monitor on the Bridge
        table, process the initial bridge set, then loop on update
        notifications until the connection goes down. An OVSDBConnectionSetup
        event is always sent (success or failure) so waiters are released.
        """
        try:
            try:
                vhost = protocol.vhost
                # Acquire and cache the switch's system-id on first contact
                if not hasattr(connection, 'ovsdb_systemid'):
                    method, params = ovsdb.transact(
                        'Open_vSwitch',
                        ovsdb.select('Open_vSwitch', [], ['external_ids']))
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                    result = self.apiroutine.jsonrpc_result[0]
                    system_id = ovsdb.omap_getvalue(
                        result['rows'][0]['external_ids'], 'system-id')
                    connection.ovsdb_systemid = system_id
                else:
                    system_id = connection.ovsdb_systemid
                # A reconnect from the same switch replaces the old connection
                if (vhost, system_id) in self.managed_systemids:
                    oc = self.managed_systemids[(vhost, system_id)]
                    ep = _get_endpoint(oc)
                    econns = self.endpoint_conns.get((vhost, ep))
                    if econns:
                        try:
                            econns.remove(oc)
                        except ValueError:
                            pass
                    del self.managed_systemids[(vhost, system_id)]
                self.managed_systemids[(vhost, system_id)] = connection
                self.managed_bridges[connection] = []
                ep = _get_endpoint(connection)
                self.endpoint_conns.setdefault((vhost, ep),
                                               []).append(connection)
                # Monitor name/datapath_id changes on the Bridge table
                method, params = ovsdb.monitor(
                    'Open_vSwitch', 'ovsdb_manager_bridges_monitor',
                    {'Bridge': ovsdb.monitor_request(['name', 'datapath_id'])})
                try:
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                except JsonRPCErrorResultException:
                    # The monitor is already set, cancel it first
                    method, params = ovsdb.monitor_cancel(
                        'ovsdb_manager_bridges_monitor')
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine, False):
                        yield m
                    method, params = ovsdb.monitor(
                        'Open_vSwitch', 'ovsdb_manager_bridges_monitor', {
                            'Bridge':
                            ovsdb.monitor_request(['name', 'datapath_id'])
                        })
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
            except Exception:
                # Release any routine waiting for this connection's setup,
                # then propagate the failure.
                # NOTE(review): if the first query fails, system_id is still
                # unbound here and this raises NameError instead — confirm
                # whether that path can occur in practice.
                for m in self.apiroutine.waitForSend(
                        OVSDBConnectionSetup(system_id, connection,
                                             connection.connmark, vhost)):
                    yield m
                raise
            else:
                # Process initial bridges
                init_subprocesses = []
                if self.apiroutine.jsonrpc_result and 'Bridge' in self.apiroutine.jsonrpc_result:
                    init_subprocesses = [
                        self._update_bridge(connection, protocol, buuid, vhost)
                        for buuid in
                        self.apiroutine.jsonrpc_result['Bridge'].keys()
                    ]

                def init_process():
                    # Run all initial _update_bridge subroutines, then send
                    # OVSDBConnectionSetup whether they succeeded or not.
                    try:
                        with closing(
                                self.apiroutine.executeAll(init_subprocesses,
                                                           retnames=())) as g:
                            for m in g:
                                yield m
                    except Exception:
                        for m in self.apiroutine.waitForSend(
                                OVSDBConnectionSetup(system_id, connection,
                                                     connection.connmark,
                                                     vhost)):
                            yield m
                        raise
                    else:
                        for m in self.apiroutine.waitForSend(
                                OVSDBConnectionSetup(system_id, connection,
                                                     connection.connmark,
                                                     vhost)):
                            yield m

                self.apiroutine.subroutine(init_process())
            # Wait for notify
            notification = JsonRPCNotificationEvent.createMatcher(
                'update',
                connection,
                connection.connmark,
                _ismatch=lambda x: x.params[
                    0] == 'ovsdb_manager_bridges_monitor')
            conn_down = protocol.statematcher(connection)
            while True:
                yield (conn_down, notification)
                if self.apiroutine.matcher is conn_down:
                    break
                else:
                    # params[1] holds the table updates of the notification
                    for buuid, v in self.apiroutine.event.params[1][
                            'Bridge'].items():
                        # If a bridge's name or datapath-id is changed, we remove this bridge and add it again
                        if 'old' in v:
                            # A bridge is deleted
                            bridges = self.managed_bridges[connection]
                            for i in range(0, len(bridges)):
                                if buuid == bridges[i][3]:
                                    self.scheduler.emergesend(
                                        OVSDBBridgeSetup(
                                            OVSDBBridgeSetup.DOWN,
                                            bridges[i][1],
                                            system_id,
                                            bridges[i][2],
                                            connection,
                                            connection.connmark,
                                            vhost,
                                            bridges[i][3],
                                            new_datapath_id=int(
                                                v['new']['datapath_id'], 16)
                                            if 'new' in v
                                            and 'datapath_id' in v['new'] else
                                            None))
                                    del self.managed_conns[(vhost,
                                                            bridges[i][1])]
                                    del bridges[i]
                                    break
                        if 'new' in v:
                            # A bridge is added
                            self.apiroutine.subroutine(
                                self._update_bridge(connection, protocol,
                                                    buuid, vhost))
        except JsonRPCProtocolException:
            pass
        finally:
            # Allow a future reconnect to start a fresh subroutine
            del connection._ovsdb_manager_get_bridges

    def _manage_existing(self):
        """Pick up JSON-RPC connections established before this module
        started, wait until each finishes its setup, then announce
        synchronization."""
        for ev in callAPI(self.apiroutine, "jsonrpcserver", "getconnections",
                          {}):
            yield ev
        bind = self.vhostbind
        connections = self.apiroutine.retvalue
        setup_matchers = []
        for conn in connections:
            # Skip connections outside the bound vhosts (None = bind all)
            if bind is not None and conn.protocol.vhost not in bind:
                continue
            if not hasattr(conn, '_ovsdb_manager_get_bridges'):
                conn._ovsdb_manager_get_bridges = self.apiroutine.subroutine(
                    self._get_bridges(conn, conn.protocol))
            setup_matchers.append(
                OVSDBConnectionSetup.createMatcher(None, conn, conn.connmark))
        for ev in self.apiroutine.waitForAll(*setup_matchers):
            yield ev
        self._synchronized = True
        for ev in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield ev

    def _wait_for_sync(self):
        """Block the calling routine until initial synchronization is done."""
        if self._synchronized:
            return
        yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                'synchronized'),)

    def _manage_conns(self):
        """Main routine: watch JSON-RPC connection up/down events and keep
        the managed connection/bridge maps in sync.

        On connection up, starts a _get_bridges subroutine (once per
        connection). On connection down, removes every registered bridge and
        broadcasts OVSDBBridgeSetup DOWN events. On module shutdown, does the
        same for all remaining connections.
        """
        try:
            # Pick up connections that already exist when the module starts
            self.apiroutine.subroutine(self._manage_existing())
            vb = self.vhostbind
            if vb is not None:
                # Restrict matching to the bound vhosts
                conn_up = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_UP,
                    _ismatch=lambda x: x.createby.vhost in vb)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                    _ismatch=lambda x: x.createby.vhost in vb)
            else:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_UP)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(
                    state=JsonRPCConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    # Guard against starting two subroutines for one connection
                    if not hasattr(self.apiroutine.event.connection,
                                   '_ovsdb_manager_get_bridges'):
                        self.apiroutine.event.connection._ovsdb_manager_get_bridges = self.apiroutine.subroutine(
                            self._get_bridges(self.apiroutine.event.connection,
                                              self.apiroutine.event.createby))
                else:
                    e = self.apiroutine.event
                    conn = e.connection
                    bridges = self.managed_bridges.get(conn)
                    if bridges is not None:
                        del self.managed_systemids[(e.createby.vhost,
                                                    conn.ovsdb_systemid)]
                        del self.managed_bridges[conn]
                        for vhost, dpid, name, buuid in bridges:
                            del self.managed_conns[(vhost, dpid)]
                            self.scheduler.emergesend(
                                OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN, dpid,
                                                 conn.ovsdb_systemid, name,
                                                 conn, conn.connmark,
                                                 e.createby.vhost, buuid))
                        # Fix: endpoint_conns is keyed by (vhost, endpoint)
                        # tuples (see _get_bridges); the previous lookup used
                        # the endpoint alone, so it always missed and the dead
                        # connection was never removed from the endpoint list.
                        econns = self.endpoint_conns.get(
                            (e.createby.vhost, _get_endpoint(conn)))
                        if econns is not None:
                            try:
                                econns.remove(conn)
                            except ValueError:
                                pass
        finally:
            # Module shutdown: stop per-connection routines and broadcast
            # DOWN for every bridge still registered
            for c in self.managed_bridges.keys():
                if hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges.close()
                bridges = self.managed_bridges.get(c)
                if bridges is not None:
                    for vhost, dpid, name, buuid in bridges:
                        del self.managed_conns[(vhost, dpid)]
                        self.scheduler.emergesend(
                            OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN, dpid,
                                             c.ovsdb_systemid, name, c,
                                             c.connmark, c.protocol.vhost,
                                             buuid))

    def getconnection(self, datapathid, vhost=''):
        "Get current connection of datapath"
        key = (vhost, datapathid)
        # Block until the initial synchronization has completed
        for ev in self._wait_for_sync():
            yield ev
        self.apiroutine.retvalue = self.managed_conns.get(key)

    def waitconnection(self, datapathid, timeout=30, vhost=''):
        "Wait for a datapath connection"
        for ev in self.getconnection(datapathid, vhost):
            yield ev
        conn = self.apiroutine.retvalue
        if conn is not None:
            # Already connected; return it directly
            self.apiroutine.retvalue = conn
            return
        # Not connected yet: wait for the bridge UP event (or timeout)
        bridge_up = OVSDBBridgeSetup.createMatcher(state=OVSDBBridgeSetup.UP,
                                                   datapathid=datapathid,
                                                   vhost=vhost)
        for ev in self.apiroutine.waitWithTimeout(timeout, bridge_up):
            yield ev
        if self.apiroutine.timeout:
            raise ConnectionResetException('Datapath is not connected')
        self.apiroutine.retvalue = self.apiroutine.event.connection

    def getdatapathids(self, vhost=''):
        "Get All datapath IDs"
        for ev in self._wait_for_sync():
            yield ev
        # managed_conns is keyed by (vhost, datapath-id)
        result = []
        for v, dpid in self.managed_conns:
            if v == vhost:
                result.append(dpid)
        self.apiroutine.retvalue = result

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return ``(vhost, datapathid)`` pair."
        for ev in self._wait_for_sync():
            yield ev
        self.apiroutine.retvalue = [key for key in self.managed_conns]

    def getallconnections(self, vhost=''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for ev in self._wait_for_sync():
            yield ev
        conns = self.managed_bridges.keys()
        if vhost is None:
            self.apiroutine.retvalue = [c for c in conns]
        else:
            self.apiroutine.retvalue = [c for c in conns
                                        if c.protocol.vhost == vhost]

    def getbridges(self, connection):
        "Get all ``(dpid, name, _uuid)`` tuple on this connection"
        for ev in self._wait_for_sync():
            yield ev
        entries = self.managed_bridges.get(connection)
        if entries is None:
            # Unknown connection
            self.apiroutine.retvalue = None
        else:
            # Drop the leading vhost element of each entry
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for _, dpid, name, buuid in entries]

    def getallbridges(self, vhost=None):
        "Get all ``(dpid, name, _uuid)`` tuple for all connections, optionally filtered by vhost"
        for ev in self._wait_for_sync():
            yield ev
        result = []
        for conn, entries in self.managed_bridges.items():
            # vhost None means no filtering
            if vhost is None or conn.protocol.vhost == vhost:
                result.extend((dpid, name, buuid)
                              for _, dpid, name, buuid in entries)
        self.apiroutine.retvalue = result

    def getbridge(self, connection, name):
        "Get datapath ID on this connection with specified name"
        for ev in self._wait_for_sync():
            yield ev
        found = None
        entries = self.managed_bridges.get(connection)
        if entries is not None:
            # Entry layout: (vhost, dpid, name, bridge _uuid)
            for entry in entries:
                if entry[2] == name:
                    found = entry[1]
                    break
        self.apiroutine.retvalue = found

    def waitbridge(self, connection, name, timeout=30):
        "Wait for bridge with specified name appears and return the datapath-id"
        allowed = self.bridgenames
        if allowed is not None and name not in allowed:
            # A filtered-out bridge will never be registered; fail fast
            raise OVSDBBridgeNotAppearException(
                'Bridge ' + repr(name) +
                ' does not appear: it is not in the selected bridge names')
        for ev in self.getbridge(connection, name):
            yield ev
        if self.apiroutine.retvalue is not None:
            return
        # Not registered yet: wait for bridge UP or connection loss
        bridge_up = OVSDBBridgeSetup.createMatcher(
            OVSDBBridgeSetup.UP, None, None, name, connection)
        down = JsonRPCConnectionStateEvent.createMatcher(
            JsonRPCConnectionStateEvent.CONNECTION_DOWN, connection,
            connection.connmark)
        for ev in self.apiroutine.waitWithTimeout(timeout, bridge_up, down):
            yield ev
        if self.apiroutine.timeout:
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) +
                                                ' does not appear')
        if self.apiroutine.matcher is down:
            raise ConnectionResetException(
                'Connection is down before bridge ' + repr(name) +
                ' appears')
        self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getbridgebyuuid(self, connection, uuid):
        "Get datapath ID of bridge on this connection with specified _uuid"
        for ev in self._wait_for_sync():
            yield ev
        found = None
        entries = self.managed_bridges.get(connection)
        if entries is not None:
            # Entry layout: (vhost, dpid, name, bridge _uuid)
            for entry in entries:
                if entry[3] == uuid:
                    found = entry[1]
                    break
        self.apiroutine.retvalue = found

    def waitbridgebyuuid(self, connection, uuid, timeout=30):
        "Wait for bridge with specified _uuid appears and return the datapath-id"
        for ev in self.getbridgebyuuid(connection, uuid):
            yield ev
        if self.apiroutine.retvalue is not None:
            return
        # Not registered yet: wait for bridge UP or connection loss
        bridge_up = OVSDBBridgeSetup.createMatcher(
            state=OVSDBBridgeSetup.UP,
            connection=connection,
            bridgeuuid=uuid)
        down = JsonRPCConnectionStateEvent.createMatcher(
            JsonRPCConnectionStateEvent.CONNECTION_DOWN, connection,
            connection.connmark)
        for ev in self.apiroutine.waitWithTimeout(timeout, bridge_up, down):
            yield ev
        if self.apiroutine.timeout:
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(uuid) +
                                                ' does not appear')
        if self.apiroutine.matcher is down:
            raise ConnectionResetException(
                'Connection is down before bridge ' + repr(uuid) +
                ' appears')
        self.apiroutine.retvalue = self.apiroutine.event.datapathid

    def getsystemids(self, vhost=''):
        "Get All system-ids"
        for ev in self._wait_for_sync():
            yield ev
        # managed_systemids is keyed by (vhost, system-id)
        self.apiroutine.retvalue = [sysid
                                    for v, sysid in self.managed_systemids
                                    if v == vhost]

    def getallsystemids(self):
        "Get all system-ids from any vhost. Return ``(vhost, system-id)`` pair."
        for ev in self._wait_for_sync():
            yield ev
        self.apiroutine.retvalue = [key for key in self.managed_systemids]

    def getconnectionbysystemid(self, systemid, vhost=''):
        "Get current connection of the switch with the specified system-id"
        key = (vhost, systemid)
        for ev in self._wait_for_sync():
            yield ev
        self.apiroutine.retvalue = self.managed_systemids.get(key)

    def waitconnectionbysystemid(self, systemid, timeout=30, vhost=''):
        "Wait for a connection with specified system-id"
        for ev in self.getconnectionbysystemid(systemid, vhost):
            yield ev
        conn = self.apiroutine.retvalue
        if conn is not None:
            # Already connected; return it directly
            self.apiroutine.retvalue = conn
            return
        # Wait for the connection setup event (or timeout)
        setup = OVSDBConnectionSetup.createMatcher(systemid, None, None,
                                                   vhost)
        for ev in self.apiroutine.waitWithTimeout(timeout, setup):
            yield ev
        if self.apiroutine.timeout:
            raise ConnectionResetException('Datapath is not connected')
        self.apiroutine.retvalue = self.apiroutine.event.connection

    def getconnectionsbyendpoint(self, endpoint, vhost=''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        key = (vhost, endpoint)
        for ev in self._wait_for_sync():
            yield ev
        self.apiroutine.retvalue = self.endpoint_conns.get(key)

    def getconnectionsbyendpointname(self, name, vhost='', timeout=30):
        """Get connection by endpoint name (Domain name, IP or IPv6 address)

        Resolves *name* through the asynchronous resolver, then looks up
        connections for each resolved address until one lookup succeeds.
        Sets ``self.apiroutine.retvalue`` to the connection list or None.
        Raises IOError if the name cannot be resolved before *timeout*.
        """
        # Resolve the name
        if not name:
            endpoint = ''
            # Fix: the method is getconnectionsbyendpoint (with an 's', as
            # called below); the old call to self.getconnectionbyendpoint
            # raised AttributeError whenever name was empty.
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM,
                       socket.IPPROTO_TCP,
                       socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname through the asynchronous resolver service
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(
                    timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                # Resolve is only allowed through asynchronous resolver
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                # Robustness: guarantee retvalue is defined even when the
                # resolver returns an empty address list
                self.apiroutine.retvalue = None
                for r in resp:
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # Ignore port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    if self.apiroutine.retvalue is not None:
                        break

    def getendpoints(self, vhost=''):
        "Get all endpoints for vhost"
        for ev in self._wait_for_sync():
            yield ev
        # endpoint_conns is keyed by (vhost, endpoint)
        self.apiroutine.retvalue = [ep
                                    for v, ep in self.endpoint_conns
                                    if v == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return ``(vhost, endpoint)`` pairs."
        for ev in self._wait_for_sync():
            yield ev
        self.apiroutine.retvalue = [key for key in self.endpoint_conns]

    def getbridgeinfo(self, datapathid, vhost=''):
        "Get ``(bridgename, systemid, bridge_uuid)`` tuple from bridge datapathid"
        for ev in self.getconnection(datapathid, vhost):
            yield ev
        conn = self.apiroutine.retvalue
        if conn is not None:
            info = None
            # Entry layout: (vhost, dpid, name, bridge _uuid); an unknown
            # connection (no entry list) also yields None
            for _, dpid, bname, buuid in self.managed_bridges.get(conn) or ():
                if dpid == datapathid:
                    info = (bname, conn.ovsdb_systemid, buuid)
                    break
            self.apiroutine.retvalue = info

    def waitbridgeinfo(self, datapathid, timeout=30, vhost=''):
        "Wait for bridge with datapathid, and return ``(bridgename, systemid, bridge_uuid)`` tuple"
        for ev in self.getbridgeinfo(datapathid, vhost):
            yield ev
        if self.apiroutine.retvalue is not None:
            return
        # Not registered yet: wait for the bridge UP event (or timeout)
        bridge_up = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.UP,
                                                   datapathid, None, None,
                                                   None, None, vhost)
        for ev in self.apiroutine.waitWithTimeout(timeout, bridge_up):
            yield ev
        if self.apiroutine.timeout:
            raise OVSDBBridgeNotAppearException(
                'Bridge 0x%016x does not appear before timeout' %
                (datapathid, ))
        event = self.apiroutine.event
        self.apiroutine.retvalue = (event.name, event.systemid,
                                    event.bridgeuuid)
# Exemplo n.º 4 (scraped example marker)
class OpenflowPortManager(Module):
    '''
    Manage Ports from Openflow Protocol

    Tracks the port list of every connected OpenFlow datapath: on connection
    setup the switch is queried for its ports, afterwards OFPT_PORT_STATUS
    messages keep the cache up to date. Changes are published as
    ModuleNotification events ('update' / 'modified') from this service.
    '''
    # Loaded as a singleton service module
    service = True
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # Port cache: (vhost, datapathid) -> {port_no: port definition}
        self.managed_ports = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.resync, self.apiroutine)
                       )
        # Becomes True once ports of all pre-existing connections are acquired
        self._synchronized = False
    def _get_ports(self, connection, protocol, onup = False, update = True):
        '''
        Query all ports of *connection* and merge them into managed_ports.

        :param onup: if True (connection just came up) and the datapath is
                     OpenFlow 1.0, reuse the features_reply received during
                     connection setup instead of issuing a new request
        :param update: if True, send an OpenflowPortSynchronized event and an
                       'update' ModuleNotification after the ports are stored
        '''
        ofdef = connection.openflowdef
        dpid = connection.openflow_datapathid
        vhost = connection.protocol.vhost
        add = []
        try:
            if hasattr(ofdef, 'ofp_multipart_request'):
                # Openflow 1.3, use ofp_multipart_request to get ports
                for m in protocol.querymultipart(ofdef.ofp_multipart_request(type=ofdef.OFPMP_PORT_DESC), connection, self.apiroutine):
                    yield m
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                # The multipart reply may span several messages
                for msg in self.apiroutine.openflow_reply:
                    for p in msg.ports:
                        add.append(p)
                        ports[p.port_no] = p
            else:
                # Openflow 1.0, use features_request
                if onup:
                    # Use the features_reply on connection setup
                    reply = connection.openflow_featuresreply
                else:
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    for m in protocol.querywithreply(request):
                        yield m
                    reply = self.apiroutine.retvalue
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                for p in reply.ports:
                    add.append(p)
                    ports[p.port_no] = p
            if update:
                # Notify waiters (waitportbyno/waitportbyname) that this
                # connection's ports are now available
                for m in self.apiroutine.waitForSend(OpenflowPortSynchronized(connection)):
                    yield m
                for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'update',
                                                                         datapathid = connection.openflow_datapathid,
                                                                         connection = connection,
                                                                         vhost = protocol.vhost,
                                                                         add = add, remove = [],
                                                                         reason = 'connected')):
                    yield m
        except ConnectionResetException:
            # Best-effort: give up silently if the connection is lost mid-query
            pass
        except OpenflowProtocolException:
            pass
    def _get_existing_ports(self):
        '''
        Acquire ports from all connections that already exist at startup,
        then mark the service synchronized and broadcast the fact.
        '''
        for m in callAPI(self.apiroutine, 'openflowmanager', 'getallconnections', {'vhost':None}):
            yield m
        # Query all main connections (auxiliaryid 0) in parallel, without
        # per-connection update notifications (update=False)
        with closing(self.apiroutine.executeAll([self._get_ports(c, c.protocol, False, False) for c in self.apiroutine.retvalue if c.openflow_auxiliaryid == 0])) as g:
            for m in g:
                yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m
    def _wait_for_sync(self):
        '''Block the calling routine until the initial synchronization is done.'''
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)
    def _manage_ports(self):
        '''
        Main routine: keep managed_ports up to date from connection updates
        and OFPT_PORT_STATUS messages.
        '''
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            conn_update = ModuleNotification.createMatcher('openflowmanager', 'update')
            port_status = OpenflowAsyncMessageEvent.createMatcher(of13.OFPT_PORT_STATUS, None, 0)
            while True:
                yield (conn_update, port_status)
                if self.apiroutine.matcher is port_status:
                    # A single port changed on a known datapath
                    e = self.apiroutine.event
                    m = e.message
                    c = e.connection
                    if (c.protocol.vhost, c.openflow_datapathid) in self.managed_ports:
                        if m.reason == c.openflowdef.OFPPR_ADD:
                            # A new port is added
                            self.managed_ports[(c.protocol.vhost, c.openflow_datapathid)][m.desc.port_no] = m.desc
                            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                                         datapathid = c.openflow_datapathid,
                                                                         connection = c,
                                                                         vhost = c.protocol.vhost,
                                                                         add = [m.desc], remove = [],
                                                                         reason = 'add'))
                        elif m.reason == c.openflowdef.OFPPR_DELETE:
                            try:
                                del self.managed_ports[(c.protocol.vhost, c.openflow_datapathid)][m.desc.port_no]
                                self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                                             datapathid = c.openflow_datapathid,
                                                                             connection = c,
                                                                             vhost = c.protocol.vhost,
                                                                             add = [], remove = [m.desc],
                                                                             reason = 'delete'))
                            except KeyError:
                                # Port was already absent; nothing to report
                                pass
                        elif m.reason == c.openflowdef.OFPPR_MODIFY:
                            try:
                                self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'modified',
                                                                             datapathid = c.openflow_datapathid,
                                                                             connection = c,
                                                                             vhost = c.protocol.vhost,
                                                                             old = self.managed_ports[(c.protocol.vhost, c.openflow_datapathid)][m.desc.port_no],
                                                                             new = m.desc,
                                                                             reason = 'modified'))
                            except KeyError:
                                # MODIFY for a port we never saw: treat as add
                                self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                                             datapathid = c.openflow_datapathid,
                                                                             connection = c,
                                                                             vhost = c.protocol.vhost,
                                                                             add = [m.desc], remove = [],
                                                                             reason = 'add'))
                            self.managed_ports[(c.protocol.vhost, c.openflow_datapathid)][m.desc.port_no] = m.desc
                else:
                    # Connection set changed in openflowmanager
                    e = self.apiroutine.event
                    for c in e.remove:
                        if c.openflow_auxiliaryid == 0 and (c.protocol.vhost, c.openflow_datapathid) in self.managed_ports:
                            # Main connection lost: drop all its ports
                            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update',
                                                 datapathid = c.openflow_datapathid,
                                                 connection = c,
                                                 vhost = c.protocol.vhost,
                                                 add = [], remove = list(self.managed_ports[(c.protocol.vhost, c.openflow_datapathid)].values()),
                                                 reason = 'disconnected'))
                            del self.managed_ports[(c.protocol.vhost, c.openflow_datapathid)]
                    for c in e.add:
                        if c.openflow_auxiliaryid == 0:
                            # New main connection: acquire its ports in background
                            self.apiroutine.subroutine(self._get_ports(c, c.protocol, True, True))
        finally:
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))
    def getports(self, datapathid, vhost = ''):
        "Return all ports of a specified datapath, or None if it is unknown"
        for m in self._wait_for_sync():
            yield m
        r = self.managed_ports.get((vhost, datapathid))
        if r is None:
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = list(r.values())
    def getallports(self, vhost = None):
        "Return all (datapathid, port, vhost) tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = [(dpid, p, vh) for (vh, dpid),v in self.managed_ports.items() for p in v.values()]
        else:
            self.apiroutine.retvalue = [(dpid, p, vh) for (vh, dpid),v in self.managed_ports.items() if vh == vhost for p in v.values()]
    def getportbyno(self, datapathid, portno, vhost = ''):
        "Return port with specified portno, or None if not found"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost)
    def _getportbyno(self, datapathid, portno, vhost = ''):
        '''Synchronous lookup of a port by port number; None when absent.'''
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        else:
            return ports.get(portno)
    def waitportbyno(self, datapathid, portno, timeout = 30, vhost = ''):
        '''
        Wait for the port with the specified portno to appear, or raise
        OpenflowPortNotAppearException after *timeout* seconds.
        '''
        for m in self._wait_for_sync():
            yield m
        def waitinner():
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                # Datapath not connected yet: wait for the connection first
                for m in callAPI(self.apiroutine, 'openflowmanager', 'waitconnection', {'datapathid': datapathid, 'vhost':vhost, 'timeout': timeout}):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected but ports not yet acquired: wait for sync event
                    yield (OpenflowPortSynchronized.createMatcher(c),)
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            if portno not in ports:
                # Wait for a PORT_STATUS carrying this port number
                yield (OpenflowAsyncMessageEvent.createMatcher(of13.OFPT_PORT_STATUS, datapathid, 0, _ismatch = lambda x: x.message.desc.port_no == portno),)
                self.apiroutine.retvalue = self.apiroutine.event.message.desc
            else:
                self.apiroutine.retvalue = ports[portno]
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException('Port %d does not appear on datapath %016x' % (portno, datapathid))
    def getportbyname(self, datapathid, name, vhost = ''):
        "Return port with specified port name, or None if not found"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)
    def _getportbyname(self, datapathid, name, vhost = ''):
        '''Synchronous lookup of a port by name (bytes); None when absent.'''
        if not isinstance(name, bytes):
            name = _bytes(name)
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        else:
            for p in ports.values():
                if p.name == name:
                    return p
            return None
    def waitportbyname(self, datapathid, name, timeout = 30, vhost = ''):
        '''
        Wait for a port with the specified name to appear, or raise
        OpenflowPortNotAppearException after *timeout* seconds.
        '''
        for m in self._wait_for_sync():
            yield m
        if not isinstance(name, bytes):
            name = _bytes(name)
        def waitinner():
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                # Datapath not connected yet: wait for the connection first
                for m in callAPI(self.apiroutine, 'openflowmanager', 'waitconnection', {'datapathid': datapathid, 'vhost':vhost, 'timeout': timeout}):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected but ports not yet acquired: wait for sync event
                    yield (OpenflowPortSynchronized.createMatcher(c),)
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            for p in ports.values():
                if p.name == name:
                    self.apiroutine.retvalue = p
                    return
            # Wait for a PORT_STATUS carrying this port name
            yield (OpenflowAsyncMessageEvent.createMatcher(of13.OFPT_PORT_STATUS, datapathid, 0, _ismatch = lambda x: x.message.desc.name == name),)
            self.apiroutine.retvalue = self.apiroutine.event.message.desc
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException('Port %r does not appear on datapath %016x' % (name, datapathid))
    def resync(self, datapathid, vhost = ''):
        '''
        Resync with current ports
        '''
        # Sometimes when the OpenFlow connection is very busy, PORT_STATUS message may be dropped.
        # We must deal with this and recover from it
        # Save current managed_ports
        if (vhost, datapathid) not in self.managed_ports:
            self.apiroutine.retvalue = None
            return
        else:
            last_ports = set(self.managed_ports[(vhost, datapathid)].keys())
        # add/remove accumulate the reconciliation decisions across retries
        add = set()
        remove = set()
        ports = {}
        # Retry up to 10 times until the acquired port list is stable
        for _ in range(0, 10):
            for m in callAPI(self.apiroutine, 'openflowmanager', 'getconnection', {'datapathid': datapathid, 'vhost':vhost}):
                yield m
            c = self.apiroutine.retvalue
            if c is None:
                # Disconnected, will automatically resync when reconnected
                self.apiroutine.retvalue = None
                return
            ofdef = c.openflowdef
            protocol = c.protocol
            try:
                if hasattr(ofdef, 'ofp_multipart_request'):
                    # Openflow 1.3, use ofp_multipart_request to get ports
                    for m in protocol.querymultipart(ofdef.ofp_multipart_request(type=ofdef.OFPMP_PORT_DESC), c, self.apiroutine):
                        yield m
                    for msg in self.apiroutine.openflow_reply:
                        for p in msg.ports:
                            ports[p.port_no] = p
                else:
                    # Openflow 1.0, use features_request
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    for m in protocol.querywithreply(request):
                        yield m
                    reply = self.apiroutine.retvalue
                    for p in reply.ports:
                        ports[p.port_no] = p
            except ConnectionResetException:
                break
            except OpenflowProtocolException:
                break
            else:
                if (vhost, datapathid) not in self.managed_ports:
                    # Connection went away while we were querying
                    self.apiroutine.retvalue = None
                    return
                current_ports = set(self.managed_ports[(vhost, datapathid)])
                # If a port is already removed
                remove.intersection_update(current_ports)
                # If a port is already added
                add.difference_update(current_ports)
                # If a port is not acquired, we do not add it
                acquired_keys = set(ports.keys())
                add.difference_update(acquired_keys)
                # Add and remove previous added/removed ports
                current_ports.difference_update(remove)
                current_ports.update(add)
                # If there are changed ports, the changed ports may or may not appear in the acquired port list
                # We only deal with following situations:
                # 1. If both lack ports, we add them
                # 2. If both have additional ports, we remove them
                to_add = acquired_keys.difference(current_ports.union(last_ports))
                to_remove = current_ports.intersection(last_ports).difference(acquired_keys)
                if not to_add and not to_remove and current_ports == last_ports:
                    # Stable: the acquired state matches the cache
                    break
                else:
                    add.update(to_add)
                    remove.update(to_remove)
                    current_ports.update(to_add)
                    current_ports.difference_update(to_remove)
                    last_ports = current_ports
        # Actual add and remove
        mports = self.managed_ports[(vhost, datapathid)]
        add_ports = []
        remove_ports = []
        for k in add:
            if k not in mports:
                add_ports.append(ports[k])
            mports[k] = ports[k]
        for k in remove:
            try:
                oldport = mports.pop(k)
            except KeyError:
                # Already removed by a concurrent PORT_STATUS
                pass
            else:
                remove_ports.append(oldport)
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'update',
                                                                 datapathid = datapathid,
                                                                 connection = c,
                                                                 vhost = vhost,
                                                                 add = add_ports, remove = remove_ports,
                                                                 reason = 'resync')):
            yield m
        self.apiroutine.retvalue = None
Exemplo n.º 5
0
class OpenflowPortManager(Module):
    '''
    Manage Ports from Openflow Protocol
    '''
    service = True

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        self.managed_ports = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.resync, self.apiroutine))
        self._synchronized = False

    def _get_ports(self, connection, protocol, onup=False, update=True):
        ofdef = connection.openflowdef
        dpid = connection.openflow_datapathid
        vhost = connection.protocol.vhost
        add = []
        try:
            if hasattr(ofdef, 'ofp_multipart_request'):
                # Openflow 1.3, use ofp_multipart_request to get ports
                for m in protocol.querymultipart(
                        ofdef.ofp_multipart_request(
                            type=ofdef.OFPMP_PORT_DESC), connection,
                        self.apiroutine):
                    yield m
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                for msg in self.apiroutine.openflow_reply:
                    for p in msg.ports:
                        add.append(p)
                        ports[p.port_no] = p
            else:
                # Openflow 1.0, use features_request
                if onup:
                    # Use the features_reply on connection setup
                    reply = connection.openflow_featuresreply
                else:
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    for m in protocol.querywithreply(request):
                        yield m
                    reply = self.apiroutine.retvalue
                ports = self.managed_ports.setdefault((vhost, dpid), {})
                for p in reply.ports:
                    add.append(p)
                    ports[p.port_no] = p
            if update:
                for m in self.apiroutine.waitForSend(
                        OpenflowPortSynchronized(connection)):
                    yield m
                self._logger.info(
                    "Openflow port synchronized on connection %r", connection)
                for m in self.apiroutine.waitForSend(
                        ModuleNotification(
                            self.getServiceName(),
                            'update',
                            datapathid=connection.openflow_datapathid,
                            connection=connection,
                            vhost=protocol.vhost,
                            add=add,
                            remove=[],
                            reason='connected')):
                    yield m
        except ConnectionResetException:
            pass
        except OpenflowProtocolException:
            pass

    def _get_existing_ports(self):
        for m in callAPI(self.apiroutine, 'openflowmanager',
                         'getallconnections', {'vhost': None}):
            yield m
        with closing(
                self.apiroutine.executeAll([
                    self._get_ports(c, c.protocol, False, False)
                    for c in self.apiroutine.retvalue
                    if c.openflow_auxiliaryid == 0
                ])) as g:
            for m in g:
                yield m
        self._synchronized = True
        self._logger.info("Openflow ports synchronized")
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'), )

    def _manage_ports(self):
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            conn_update = ModuleNotification.createMatcher(
                'openflowmanager', 'update')
            port_status = OpenflowAsyncMessageEvent.createMatcher(
                of13.OFPT_PORT_STATUS, None, 0)
            while True:
                yield (conn_update, port_status)
                if self.apiroutine.matcher is port_status:
                    e = self.apiroutine.event
                    m = e.message
                    c = e.connection
                    if (c.protocol.vhost,
                            c.openflow_datapathid) in self.managed_ports:
                        if m.reason == c.openflowdef.OFPPR_ADD:
                            # A new port is added
                            self.managed_ports[(c.protocol.vhost,
                                                c.openflow_datapathid
                                                )][m.desc.port_no] = m.desc
                            self.scheduler.emergesend(
                                ModuleNotification(
                                    self.getServiceName(),
                                    'update',
                                    datapathid=c.openflow_datapathid,
                                    connection=c,
                                    vhost=c.protocol.vhost,
                                    add=[m.desc],
                                    remove=[],
                                    reason='add'))
                        elif m.reason == c.openflowdef.OFPPR_DELETE:
                            try:
                                del self.managed_ports[(
                                    c.protocol.vhost,
                                    c.openflow_datapathid)][m.desc.port_no]
                                self.scheduler.emergesend(
                                    ModuleNotification(
                                        self.getServiceName(),
                                        'update',
                                        datapathid=c.openflow_datapathid,
                                        connection=c,
                                        vhost=c.protocol.vhost,
                                        add=[],
                                        remove=[m.desc],
                                        reason='delete'))
                            except KeyError:
                                pass
                        elif m.reason == c.openflowdef.OFPPR_MODIFY:
                            try:
                                self.scheduler.emergesend(
                                    ModuleNotification(
                                        self.getServiceName(),
                                        'modified',
                                        datapathid=c.openflow_datapathid,
                                        connection=c,
                                        vhost=c.protocol.vhost,
                                        old=self.managed_ports[(
                                            c.protocol.vhost,
                                            c.openflow_datapathid
                                        )][m.desc.port_no],
                                        new=m.desc,
                                        reason='modified'))
                            except KeyError:
                                self.scheduler.emergesend(
                                    ModuleNotification(
                                        self.getServiceName(),
                                        'update',
                                        datapathid=c.openflow_datapathid,
                                        connection=c,
                                        vhost=c.protocol.vhost,
                                        add=[m.desc],
                                        remove=[],
                                        reason='add'))
                            self.managed_ports[(c.protocol.vhost,
                                                c.openflow_datapathid
                                                )][m.desc.port_no] = m.desc
                else:
                    e = self.apiroutine.event
                    for c in e.remove:
                        if c.openflow_auxiliaryid == 0 and (
                                c.protocol.vhost,
                                c.openflow_datapathid) in self.managed_ports:
                            self.scheduler.emergesend(
                                ModuleNotification(
                                    self.getServiceName(),
                                    'update',
                                    datapathid=c.openflow_datapathid,
                                    connection=c,
                                    vhost=c.protocol.vhost,
                                    add=[],
                                    remove=list(self.managed_ports[(
                                        c.protocol.vhost,
                                        c.openflow_datapathid)].values()),
                                    reason='disconnected'))
                            del self.managed_ports[(c.protocol.vhost,
                                                    c.openflow_datapathid)]
                    for c in e.add:
                        if c.openflow_auxiliaryid == 0:
                            self.apiroutine.subroutine(
                                self._get_ports(c, c.protocol, True, True))
        finally:
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getports(self, datapathid, vhost=''):
        "Return all ports of a specifed datapath"
        for m in self._wait_for_sync():
            yield m
        r = self.managed_ports.get((vhost, datapathid))
        if r is None:
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = list(r.values())

    def getallports(self, vhost=None):
        "Return all ``(datapathid, port, vhost)`` tuples, optionally filterd by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = [
                (dpid, p, vh) for (vh, dpid), v in self.managed_ports.items()
                for p in v.values()
            ]
        else:
            self.apiroutine.retvalue = [
                (dpid, p, vh) for (vh, dpid), v in self.managed_ports.items()
                if vh == vhost for p in v.values()
            ]

    def getportbyno(self, datapathid, portno, vhost=''):
        "Return port with specified OpenFlow portno"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost)

    def _getportbyno(self, datapathid, portno, vhost=''):
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        else:
            return ports.get(portno)

    def waitportbyno(self, datapathid, portno, timeout=30, vhost=''):
        """
        Wait for the specified OpenFlow portno to appear, or until timeout.

        The result is stored in ``self.apiroutine.retvalue``: the cached
        port structure when the port is already known, otherwise the
        ``desc`` field of the PORT_STATUS message that announces it.

        :param datapathid: OpenFlow datapath ID
        :param portno: OpenFlow port number to wait for
        :param timeout: maximum seconds to wait
        :param vhost: virtual host of the OpenFlow connection
        :raises ConnectionResetException: if the datapath is still not
                managed after waiting for its connection
        :raises OpenflowPortNotAppearException: if the port does not appear
                within *timeout* seconds
        """
        # Block until the initial port synchronization has finished
        for m in self._wait_for_sync():
            yield m

        def waitinner():
            # Inner routine; wrapped with executeWithTimeout below
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                # The datapath is not managed yet: wait for its OpenFlow
                # connection to be established first
                for m in callAPI(self.apiroutine, 'openflowmanager',
                                 'waitconnection', {
                                     'datapathid': datapathid,
                                     'vhost': vhost,
                                     'timeout': timeout
                                 }):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected but ports not yet synchronized; wait for the
                    # per-connection synchronization event
                    yield (OpenflowPortSynchronized.createMatcher(c), )
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Still no port table: treat as a lost connection
                    raise ConnectionResetException(
                        'Datapath %016x is not connected' % datapathid)
            if portno not in ports:
                # Port not cached yet; wait for a PORT_STATUS message
                # carrying this port number
                yield (OpenflowAsyncMessageEvent.createMatcher(
                    of13.OFPT_PORT_STATUS,
                    datapathid,
                    0,
                    _ismatch=lambda x: x.message.desc.port_no == portno), )
                self.apiroutine.retvalue = self.apiroutine.event.message.desc
            else:
                self.apiroutine.retvalue = ports[portno]

        # Run the waiting logic with an overall timeout
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException(
                'Port %d does not appear on datapath %016x' %
                (portno, datapathid))

    def getportbyname(self, datapathid, name, vhost=''):
        """Return the cached port with the specified port name (None if absent)."""
        for ev in self._wait_for_sync():
            yield ev
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)

    def _getportbyname(self, datapathid, name, vhost=''):
        if not isinstance(name, bytes):
            name = _bytes(name)
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        else:
            for p in ports.values():
                if p.name == name:
                    return p
            return None

    def waitportbyname(self, datapathid, name, timeout=30, vhost=''):
        """
        Wait for a port with the specified port name to appear, or until timeout

        The result is stored in ``self.apiroutine.retvalue``: the cached
        port structure when the name is already known, otherwise the
        ``desc`` field of the PORT_STATUS message that announces it.

        :param datapathid: OpenFlow datapath ID
        :param name: port name (str or bytes; converted to bytes)
        :param timeout: maximum seconds to wait
        :param vhost: virtual host of the OpenFlow connection
        :raises ConnectionResetException: if the datapath is still not
                managed after waiting for its connection
        :raises OpenflowPortNotAppearException: if the port does not appear
                within *timeout* seconds
        """
        # Block until the initial port synchronization has finished
        for m in self._wait_for_sync():
            yield m
        # OpenFlow port names are bytes; normalize the argument
        if not isinstance(name, bytes):
            name = _bytes(name)

        def waitinner():
            # Inner routine; wrapped with executeWithTimeout below
            ports = self.managed_ports.get((vhost, datapathid))
            if ports is None:
                # The datapath is not managed yet: wait for its OpenFlow
                # connection to be established first
                for m in callAPI(self.apiroutine, 'openflowmanager',
                                 'waitconnection', {
                                     'datapathid': datapathid,
                                     'vhost': vhost,
                                     'timeout': timeout
                                 }):
                    yield m
                c = self.apiroutine.retvalue
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Connected but ports not yet synchronized; wait for the
                    # per-connection synchronization event
                    yield (OpenflowPortSynchronized.createMatcher(c), )
                ports = self.managed_ports.get((vhost, datapathid))
                if ports is None:
                    # Still no port table: treat as a lost connection
                    raise ConnectionResetException(
                        'Datapath %016x is not connected' % datapathid)
            for p in ports.values():
                if p.name == name:
                    self.apiroutine.retvalue = p
                    return
            # Not found in the cache; wait for a PORT_STATUS message
            # carrying this port name
            yield (OpenflowAsyncMessageEvent.createMatcher(
                of13.OFPT_PORT_STATUS,
                datapathid,
                0,
                _ismatch=lambda x: x.message.desc.name == name), )
            self.apiroutine.retvalue = self.apiroutine.event.message.desc

        # Run the waiting logic with an overall timeout
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OpenflowPortNotAppearException(
                'Port %r does not appear on datapath %016x' %
                (name, datapathid))

    def resync(self, datapathid, vhost=''):
        '''
        Resync with current ports

        Queries the switch directly for its current port list (OpenFlow 1.3+
        multipart OFPMP_PORT_DESC when available, otherwise an OpenFlow 1.0
        FEATURES_REQUEST) and reconciles the result with the cached
        ``managed_ports``, retrying up to 10 rounds until the view is
        stable. Finally broadcasts a ModuleNotification
        ('update', reason='resync') listing the added/removed ports.
        Sets ``self.apiroutine.retvalue`` to None.
        '''
        # Sometimes when the OpenFlow connection is very busy, PORT_STATUS message may be dropped.
        # We must deal with this and recover from it
        # Save current managed_ports
        if (vhost, datapathid) not in self.managed_ports:
            self.apiroutine.retvalue = None
            return
        else:
            last_ports = set(self.managed_ports[(vhost, datapathid)].keys())
        add = set()
        remove = set()
        # NOTE(review): `ports` accumulates across retry rounds; entries from
        # earlier rounds are kept — confirm this is intended.
        ports = {}
        for _ in range(0, 10):
            for m in callAPI(self.apiroutine, 'openflowmanager',
                             'getconnection', {
                                 'datapathid': datapathid,
                                 'vhost': vhost
                             }):
                yield m
            c = self.apiroutine.retvalue
            if c is None:
                # Disconnected, will automatically resync when reconnected
                self.apiroutine.retvalue = None
                return
            ofdef = c.openflowdef
            protocol = c.protocol
            try:
                if hasattr(ofdef, 'ofp_multipart_request'):
                    # Openflow 1.3, use ofp_multipart_request to get ports
                    for m in protocol.querymultipart(
                            ofdef.ofp_multipart_request(
                                type=ofdef.OFPMP_PORT_DESC), c,
                            self.apiroutine):
                        yield m
                    for msg in self.apiroutine.openflow_reply:
                        for p in msg.ports:
                            ports[p.port_no] = p
                else:
                    # Openflow 1.0, use features_request
                    request = ofdef.ofp_msg()
                    request.header.type = ofdef.OFPT_FEATURES_REQUEST
                    for m in protocol.querywithreply(request):
                        yield m
                    reply = self.apiroutine.retvalue
                    for p in reply.ports:
                        ports[p.port_no] = p
            except ConnectionResetException:
                break
            except OpenflowProtocolException:
                break
            else:
                if (vhost, datapathid) not in self.managed_ports:
                    self.apiroutine.retvalue = None
                    return
                current_ports = set(self.managed_ports[(vhost, datapathid)])
                # If a port is already removed
                remove.intersection_update(current_ports)
                # If a port is already added
                add.difference_update(current_ports)
                # If a port is not acquired, we do not add it
                # (fixed: was difference_update, which contradicts this
                # comment by keeping exactly the NON-acquired ports)
                acquired_keys = set(ports.keys())
                add.intersection_update(acquired_keys)
                # Add and remove previous added/removed ports
                current_ports.difference_update(remove)
                current_ports.update(add)
                # If there are changed ports, the changed ports may or may not appear in the acquired port list
                # We only deal with following situations:
                # 1. If both lack ports, we add them
                # 2. If both have additional ports, we remove them
                to_add = acquired_keys.difference(
                    current_ports.union(last_ports))
                to_remove = current_ports.intersection(last_ports).difference(
                    acquired_keys)
                if not to_add and not to_remove and current_ports == last_ports:
                    break
                else:
                    add.update(to_add)
                    remove.update(to_remove)
                    current_ports.update(to_add)
                    current_ports.difference_update(to_remove)
                    last_ports = current_ports
        # Actual add and remove
        mports = self.managed_ports[(vhost, datapathid)]
        add_ports = []
        remove_ports = []
        for k in add:
            if k not in mports:
                add_ports.append(ports[k])
            mports[k] = ports[k]
        for k in remove:
            try:
                oldport = mports.pop(k)
            except KeyError:
                pass
            else:
                remove_ports.append(oldport)
        # Broadcast the reconciliation result
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(),
                                   'update',
                                   datapathid=datapathid,
                                   connection=c,
                                   vhost=vhost,
                                   add=add_ports,
                                   remove=remove_ports,
                                   reason='resync')):
            yield m
        self.apiroutine.retvalue = None
Exemplo n.º 6
0
class OVSDBPortManager(Module):
    '''
    Manage Ports from OVSDB Protocol
    '''
    service = True

    def __init__(self, server):
        Module.__init__(self, server)
        # Main routine container hosting all API handlers of this module
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # (vhost, datapath_id) -> list of (port_uuid, interface_info_dict)
        self.managed_ports = {}
        # (vhost, iface-id) -> list of (datapath_id, interface_info_dict)
        self.managed_ids = {}
        # Active per-connection monitoring routines (see _get_ports)
        self.monitor_routines = set()
        # port _uuid -> _uuid of the bridge currently containing the port
        self.ports_uuids = {}
        # Pending waiters keyed by (vhost, datapath_id, ofport);
        # cleared by _get_interface_info when the port comes up
        self.wait_portnos = {}
        # Pending waiters keyed by (vhost, datapath_id, port_name)
        self.wait_names = {}
        # Pending waiters keyed by (vhost, iface-id)
        self.wait_ids = {}
        # bridge _uuid -> OpenFlow datapath_id
        self.bridge_datapathid = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyid, self.apiroutine),
                       api(self.waitportbyid, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.resync, self.apiroutine))
        # Becomes True after the initial synchronization finishes
        self._synchronized = False

    def _get_interface_info(self, connection, protocol, buuid, interface_uuid,
                            port_uuid):
        """
        Acquire the OVSDB Interface row *interface_uuid* and record it under
        the bridge *buuid* / port *port_uuid*.

        Sets ``self.apiroutine.retvalue`` to a one-element list with the
        interface info dict, or to an empty list when the interface is in an
        error state, has disappeared, or the bridge is no longer tracked.
        """
        try:
            # Single transaction: three "wait" operations ensure ofport and
            # ifindex are populated (and ofport is not -1, the error marker)
            # before the final "select" reads the row
            method, params = ovsdb.transact(
                'Open_vSwitch',
                ovsdb.wait(
                    'Interface',
                    [["_uuid", "==", ovsdb.uuid(interface_uuid)]], ["ofport"],
                    [{
                        "ofport": ovsdb.oset()
                    }], False, 5000),
                ovsdb.wait(
                    'Interface',
                    [["_uuid", "==", ovsdb.uuid(interface_uuid)]], ["ofport"],
                    [{
                        "ofport": -1
                    }], False, 0),
                ovsdb.wait(
                    'Interface',
                    [["_uuid", "==", ovsdb.uuid(interface_uuid)]], ["ifindex"],
                    [{
                        "ifindex": ovsdb.oset()
                    }], False, 5000),
                ovsdb.select(
                    'Interface',
                    [["_uuid", "==", ovsdb.uuid(interface_uuid)]], [
                        "_uuid", "name", "ifindex", "ofport", "type",
                        "external_ids"
                    ]))
            for m in protocol.querywithreply(method, params, connection,
                                             self.apiroutine):
                yield m
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[2]
            if 'error' in r:
                # Ignore this port because it is in an error state
                self.apiroutine.retvalue = []
                return
            r = self.apiroutine.jsonrpc_result[3]
            if 'error' in r:
                raise JsonRPCErrorResultException(
                    'Error while acquiring interface: ' + repr(r['error']))
            if not r['rows']:
                # The interface row disappeared before we could read it
                self.apiroutine.retvalue = []
                return
            r0 = r['rows'][0]
            if r0['ofport'] < 0:
                # Ignore this port because it is in an error state
                self.apiroutine.retvalue = []
                return
            # OVSDB encodes uuids as a 2-element array; keep the bare value
            r0['_uuid'] = r0['_uuid'][1]
            r0['ifindex'] = ovsdb.getoptional(r0['ifindex'])
            r0['external_ids'] = ovsdb.getdict(r0['external_ids'])
            if buuid not in self.bridge_datapathid:
                # The bridge is no longer tracked; drop the result
                self.apiroutine.retvalue = []
                return
            else:
                datapath_id = self.bridge_datapathid[buuid]
            if 'iface-id' in r0['external_ids']:
                # Also index the interface by its logical endpoint id
                eid = r0['external_ids']['iface-id']
                r0['id'] = eid
                id_ports = self.managed_ids.setdefault((protocol.vhost, eid),
                                                       [])
                id_ports.append((datapath_id, r0))
            else:
                r0['id'] = None
            self.managed_ports.setdefault((protocol.vhost, datapath_id),
                                          []).append((port_uuid, r0))
            # Wake up waiters registered for this port number, name or id
            notify = False
            if (protocol.vhost, datapath_id,
                    r0['ofport']) in self.wait_portnos:
                notify = True
                del self.wait_portnos[(protocol.vhost, datapath_id,
                                       r0['ofport'])]
            if (protocol.vhost, datapath_id, r0['name']) in self.wait_names:
                notify = True
                del self.wait_names[(protocol.vhost, datapath_id, r0['name'])]
            if (protocol.vhost, r0['id']) in self.wait_ids:
                notify = True
                del self.wait_ids[(protocol.vhost, r0['id'])]
            if notify:
                for m in self.apiroutine.waitForSend(
                        OVSDBPortUpNotification(connection,
                                                r0['name'],
                                                r0['ofport'],
                                                r0['id'],
                                                protocol.vhost,
                                                datapath_id,
                                                port=r0)):
                    yield m
            self.apiroutine.retvalue = [r0]
        except JsonRPCProtocolException:
            # Connection-level JSON-RPC failure: treat as no interface
            self.apiroutine.retvalue = []

    def _remove_interface_id(self, connection, protocol, datapath_id, port):
        eid = port['id']
        eid_list = self.managed_ids.get((protocol.vhost, eid))
        for i in range(0, len(eid_list)):
            if eid_list[i][1]['_uuid'] == port['_uuid']:
                del eid_list[i]
                break

    def _remove_interface(self, connection, protocol, datapath_id,
                          interface_uuid, port_uuid):
        ports = self.managed_ports.get((protocol.vhost, datapath_id))
        r = None
        if ports is not None:
            for i in range(0, len(ports)):
                if ports[i][1]['_uuid'] == interface_uuid:
                    r = ports[i][1]
                    if r['id']:
                        self._remove_interface_id(connection, protocol,
                                                  datapath_id, r)
                    del ports[i]
                    break
            if not ports:
                del self.managed_ports[(protocol.vhost, datapath_id)]
        return r

    def _remove_all_interface(self, connection, protocol, datapath_id,
                              port_uuid, buuid):
        ports = self.managed_ports.get((protocol.vhost, datapath_id))
        if ports is not None:
            removed_ports = [r for puuid, r in ports if puuid == port_uuid]
            not_removed_ports = [(puuid, r) for puuid, r in ports
                                 if puuid != port_uuid]
            ports[:len(not_removed_ports)] = not_removed_ports
            del ports[len(not_removed_ports):]
            for r in removed_ports:
                if r['id']:
                    self._remove_interface_id(connection, protocol,
                                              datapath_id, r)
            if not ports:
                del self.managed_ports[(protocol.vhost, datapath_id)]
            return removed_ports
        if port_uuid in self.ports_uuids and self.ports_uuids[
                port_uuid] == buuid:
            del self.ports_uuids[port_uuid]
        return []

    def _update_interfaces(self,
                           connection,
                           protocol,
                           updateinfo,
                           update=True):
        """
        There are several kinds of updates, they may appear together:

        1. New bridge created (or from initial updateinfo). We should add all the interfaces to the list.

        2. Bridge removed. Remove all the ports.

        3. name and datapath_id may be changed. We will consider this as a new bridge created, and an old
           bridge removed.

        4. Bridge ports modification, i.e. add/remove ports.
           a) Normally a port record is created/deleted together. A port record cannot exist without a
              bridge containing it.

           b) It is also possible that a port is removed from one bridge and added to another bridge, in
              this case the ports do not appear in the updateinfo

        5. Port interfaces modification, i.e. add/remove interfaces. The bridge record may not appear in this
           situation.

        We must consider these situations carefully and process them in correct order.

        When *update* is False (initial state) the routines run to
        completion and OVSDBConnectionPortsSynchronized is always sent.
        """
        # Per-table update maps: row uuid -> {'old': ..., 'new': ...}
        port_update = updateinfo.get('Port', {})
        bridge_update = updateinfo.get('Bridge', {})
        working_routines = []

        def process_bridge(buuid, uo):
            # Handle a created/modified bridge row: resolve its datapath_id,
            # diff its port set and collect interface add/remove work
            try:
                nv = uo['new']
                if 'datapath_id' in nv:
                    if ovsdb.getoptional(nv['datapath_id']) is None:
                        # This bridge is not initialized. Wait for the bridge to be initialized.
                        for m in callAPI(
                                self.apiroutine, 'ovsdbmanager', 'waitbridge',
                            {
                                'connection': connection,
                                'name': nv['name'],
                                'timeout': 5
                            }):
                            yield m
                        datapath_id = self.apiroutine.retvalue
                    else:
                        datapath_id = int(nv['datapath_id'], 16)
                    self.bridge_datapathid[buuid] = datapath_id
                elif buuid in self.bridge_datapathid:
                    datapath_id = self.bridge_datapathid[buuid]
                else:
                    # This should not happen, but just in case...
                    for m in callAPI(self.apiroutine, 'ovsdbmanager',
                                     'waitbridge', {
                                         'connection': connection,
                                         'name': nv['name'],
                                         'timeout': 5
                                     }):
                        yield m
                    datapath_id = self.apiroutine.retvalue
                    self.bridge_datapathid[buuid] = datapath_id
                if 'ports' in nv:
                    nset = set((p for _, p in ovsdb.getlist(nv['ports'])))
                else:
                    nset = set()
                if 'old' in uo:
                    ov = uo['old']
                    if 'ports' in ov:
                        oset = set((p for _, p in ovsdb.getlist(ov['ports'])))
                    else:
                        # new ports are not really added; it is only sent because datapath_id is modified
                        nset = set()
                        oset = set()
                    if 'datapath_id' in ov and ovsdb.getoptional(
                            ov['datapath_id']) is not None:
                        old_datapathid = int(ov['datapath_id'], 16)
                    else:
                        old_datapathid = datapath_id
                else:
                    oset = set()
                    old_datapathid = datapath_id
                # For every deleted port, remove the interfaces with this port _uuid
                remove = []
                add_routine = []
                for puuid in oset - nset:
                    remove += self._remove_all_interface(
                        connection, protocol, old_datapathid, puuid, buuid)
                # For every port not changed, check if the interfaces are modified;
                for puuid in oset.intersection(nset):
                    if puuid in port_update:
                        # The port is modified, there should be an 'old' set and 'new' set
                        pu = port_update[puuid]
                        if 'old' in pu:
                            poset = set((p for _, p in ovsdb.getlist(
                                pu['old']['interfaces'])))
                        else:
                            poset = set()
                        if 'new' in pu:
                            pnset = set((p for _, p in ovsdb.getlist(
                                pu['new']['interfaces'])))
                        else:
                            pnset = set()
                        # Remove old interfaces
                        remove += [
                            r for r in (self._remove_interface(
                                connection, protocol, datapath_id, iuuid,
                                puuid) for iuuid in (poset - pnset))
                            if r is not None
                        ]
                        # Prepare to add new interfaces
                        add_routine += [
                            self._get_interface_info(connection, protocol,
                                                     buuid, iuuid, puuid)
                            for iuuid in (pnset - poset)
                        ]
                # For every port added, add the interfaces
                def add_port_interfaces(puuid):
                    # If the uuid does not appear in update info, we have no choice but to query interfaces with select
                    # we cannot use data from other bridges; the port may be moved from a bridge which is not tracked
                    try:
                        method, params = ovsdb.transact(
                            'Open_vSwitch',
                            ovsdb.select('Port',
                                         [["_uuid", "==",
                                           ovsdb.uuid(puuid)]],
                                         ["interfaces"]))
                        for m in protocol.querywithreply(
                                method, params, connection, self.apiroutine):
                            yield m
                        r = self.apiroutine.jsonrpc_result[0]
                        if 'error' in r:
                            raise JsonRPCErrorResultException(
                                'Error when query interfaces from port ' +
                                repr(puuid) + ': ' + r['error'])
                        if r['rows']:
                            interfaces = ovsdb.getlist(
                                r['rows'][0]['interfaces'])
                            with closing(
                                    self.apiroutine.executeAll([
                                        self._get_interface_info(
                                            connection, protocol, buuid, iuuid,
                                            puuid) for _, iuuid in interfaces
                                    ])) as g:
                                for m in g:
                                    yield m
                            self.apiroutine.retvalue = list(
                                itertools.chain(
                                    r[0] for r in self.apiroutine.retvalue))
                        else:
                            self.apiroutine.retvalue = []
                    except JsonRPCProtocolException:
                        self.apiroutine.retvalue = []
                    except ConnectionResetException:
                        self.apiroutine.retvalue = []

                for puuid in nset - oset:
                    self.ports_uuids[puuid] = buuid
                    if puuid in port_update and 'new' in port_update[puuid] \
                            and 'old' not in port_update[puuid]:
                        # Add all the interfaces in 'new'
                        interfaces = ovsdb.getlist(
                            port_update[puuid]['new']['interfaces'])
                        add_routine += [
                            self._get_interface_info(connection, protocol,
                                                     buuid, iuuid, puuid)
                            for _, iuuid in interfaces
                        ]
                    else:
                        add_routine.append(add_port_interfaces(puuid))
                # Execute the add_routine
                try:
                    with closing(self.apiroutine.executeAll(add_routine)) as g:
                        for m in g:
                            yield m
                except:
                    add = []
                    raise
                else:
                    add = list(
                        itertools.chain(r[0]
                                        for r in self.apiroutine.retvalue))
                finally:
                    if update:
                        # Broadcast the port changes of this bridge
                        self.scheduler.emergesend(
                            ModuleNotification(self.getServiceName(),
                                               'update',
                                               datapathid=datapath_id,
                                               connection=connection,
                                               vhost=protocol.vhost,
                                               add=add,
                                               remove=remove,
                                               reason='bridgemodify'
                                               if 'old' in uo else 'bridgeup'))
            except JsonRPCProtocolException:
                pass
            except ConnectionResetException:
                pass
            except OVSDBBridgeNotAppearException:
                pass

        ignore_ports = set()
        for buuid, uo in bridge_update.items():
            # Bridge removals are ignored because we process OVSDBBridgeSetup event instead
            if 'old' in uo:
                if 'ports' in uo['old']:
                    oset = set(
                        (puuid
                         for _, puuid in ovsdb.getlist(uo['old']['ports'])))
                    ignore_ports.update(oset)
                if 'new' not in uo:
                    if buuid in self.bridge_datapathid:
                        del self.bridge_datapathid[buuid]
            if 'new' in uo:
                # If bridge contains this port is updated, we process the port update totally in bridge,
                # so we ignore it later
                if 'ports' in uo['new']:
                    nset = set(
                        (puuid
                         for _, puuid in ovsdb.getlist(uo['new']['ports'])))
                    ignore_ports.update(nset)
                working_routines.append(process_bridge(buuid, uo))

        def process_port(buuid, port_uuid, interfaces, remove_ids):
            # Handle a modified port row whose bridge row did not change:
            # drop interfaces in remove_ids, then fetch those in interfaces
            if buuid not in self.bridge_datapathid:
                return
            datapath_id = self.bridge_datapathid[buuid]
            ports = self.managed_ports.get((protocol.vhost, datapath_id))
            remove = []
            if ports is not None:
                remove = [p for _, p in ports if p['_uuid'] in remove_ids]
                not_remove = [(_, p) for _, p in ports
                              if p['_uuid'] not in remove_ids]
                ports[:len(not_remove)] = not_remove
                del ports[len(not_remove):]
            if interfaces:
                try:
                    with closing(
                            self.apiroutine.executeAll([
                                self._get_interface_info(
                                    connection, protocol, buuid, iuuid,
                                    port_uuid) for iuuid in interfaces
                            ])) as g:
                        for m in g:
                            yield m
                    add = list(
                        itertools.chain(
                            (r[0] for r in self.apiroutine.retvalue if r[0])))
                except Exception:
                    self._logger.warning("Cannot get new port information",
                                         exc_info=True)
                    add = []
            else:
                add = []
            if update:
                for m in self.apiroutine.waitForSend(
                        ModuleNotification(self.getServiceName(),
                                           'update',
                                           datapathid=datapath_id,
                                           connection=connection,
                                           vhost=protocol.vhost,
                                           add=add,
                                           remove=remove,
                                           reason='bridgemodify')):
                    yield m

        for puuid, po in port_update.items():
            if puuid not in ignore_ports:
                bridge_id = self.ports_uuids.get(puuid)
                if bridge_id is not None:
                    datapath_id = self.bridge_datapathid[bridge_id]
                    if datapath_id is not None:
                        # This port is modified
                        if 'new' in po:
                            nset = set((iuuid for _, iuuid in ovsdb.getlist(
                                po['new']['interfaces'])))
                        else:
                            nset = set()
                        if 'old' in po:
                            oset = set((iuuid for _, iuuid in ovsdb.getlist(
                                po['old']['interfaces'])))
                        else:
                            oset = set()
                        working_routines.append(
                            process_port(bridge_id, puuid, nset - oset,
                                         oset - nset))
        if update:
            # Incremental update: run each routine independently
            for r in working_routines:
                self.apiroutine.subroutine(r)
        else:
            # Initial state: wait for all routines, then always signal that
            # this connection's ports are synchronized
            try:
                with closing(
                        self.apiroutine.executeAll(working_routines, None,
                                                   ())) as g:
                    for m in g:
                        yield m
            finally:
                self.scheduler.emergesend(
                    OVSDBConnectionPortsSynchronized(connection))

    def _get_ports(self, connection, protocol):
        """
        Monitor routine for one OVSDB connection: register a monitor on the
        Bridge and Port tables, feed the initial state and every subsequent
        update notification into ``_update_interfaces``, and deregister
        itself from ``monitor_routines`` on exit.
        """
        try:
            try:
                method, params = ovsdb.monitor(
                    'Open_vSwitch', 'ovsdb_port_manager_interfaces_monitor', {
                        'Bridge': [
                            ovsdb.monitor_request(
                                ["name", "datapath_id", "ports"])
                        ],
                        'Port': [ovsdb.monitor_request(["interfaces"])]
                    })
                try:
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                except JsonRPCErrorResultException:
                    # The monitor is already set, cancel it first
                    method2, params2 = ovsdb.monitor_cancel(
                        'ovsdb_port_manager_interfaces_monitor')
                    for m in protocol.querywithreply(method2, params2,
                                                     connection,
                                                     self.apiroutine, False):
                        yield m
                    for m in protocol.querywithreply(method, params,
                                                     connection,
                                                     self.apiroutine):
                        yield m
                r = self.apiroutine.jsonrpc_result
            except:

                # Whatever went wrong, waiters must still be released:
                # send the synchronized event before re-raising
                def _msg():
                    for m in self.apiroutine.waitForSend(
                            OVSDBConnectionPortsSynchronized(connection)):
                        yield m

                self.apiroutine.subroutine(_msg(), False)
                raise
            # This is the initial state, it should contain all the ids of ports and interfaces
            self.apiroutine.subroutine(
                self._update_interfaces(connection, protocol, r, False))
            update_matcher = JsonRPCNotificationEvent.createMatcher(
                'update',
                connection,
                connection.connmark,
                _ismatch=lambda x: x.params[
                    0] == 'ovsdb_port_manager_interfaces_monitor')
            conn_state = protocol.statematcher(connection)
            while True:
                # Process monitor updates until the connection goes down
                yield (update_matcher, conn_state)
                if self.apiroutine.matcher is conn_state:
                    break
                else:
                    self.apiroutine.subroutine(
                        self._update_interfaces(
                            connection, protocol,
                            self.apiroutine.event.params[1], True))
        except JsonRPCProtocolException:
            pass
        finally:
            # Deregister this routine from the active monitor set
            if self.apiroutine.currentroutine in self.monitor_routines:
                self.monitor_routines.remove(self.apiroutine.currentroutine)

    def _get_existing_ports(self):
        '''
        Start a monitor routine for every existing OVSDB connection and wait
        until all of them finish the initial port synchronization, then mark
        the module as synchronized.
        '''
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getallconnections',
                         {'vhost': None}):
            yield m
        sync_matchers = []
        for conn in self.apiroutine.retvalue:
            routine = self.apiroutine.subroutine(
                self._get_ports(conn, conn.protocol))
            self.monitor_routines.add(routine)
            sync_matchers.append(
                OVSDBConnectionPortsSynchronized.createMatcher(conn))
        for m in self.apiroutine.waitForAll(*sync_matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        # Block until this module has finished its initial synchronization
        if self._synchronized:
            return
        yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                'synchronized'), )

    def _manage_ports(self):
        '''
        Main routine: start monitors for the existing OVSDB connections, then
        react to new connections and bridge-down events, keeping
        ``managed_ports`` up to date and broadcasting 'update' notifications.
        '''
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            connsetup = OVSDBConnectionSetup.createMatcher()
            bridgedown = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.DOWN)
            while True:
                yield (connsetup, bridgedown)
                e = self.apiroutine.event
                if self.apiroutine.matcher is connsetup:
                    # A new OVSDB connection: start monitoring its ports
                    self.monitor_routines.add(
                        self.apiroutine.subroutine(
                            self._get_ports(e.connection,
                                            e.connection.protocol)))
                else:
                    # Remove ports of the bridge
                    ports = self.managed_ports.get((e.vhost, e.datapathid))
                    if ports is not None:
                        ports_original = ports
                        ports = [p for _, p in ports]
                        for p in ports:
                            if p['id']:
                                self._remove_interface_id(
                                    e.connection, e.connection.protocol,
                                    e.datapathid, p)
                        newdpid = getattr(e, 'new_datapath_id', None)
                        buuid = e.bridgeuuid
                        if newdpid is not None:
                            # This bridge changes its datapath id
                            if buuid in self.bridge_datapathid and self.bridge_datapathid[
                                    buuid] == e.datapathid:
                                self.bridge_datapathid[buuid] = newdpid

                            def re_add_interfaces():
                                # Query every port again and re-announce them
                                # under the new datapath id
                                with closing(
                                        self.apiroutine.executeAll([
                                            self._get_interface_info(
                                                e.connection,
                                                e.connection.protocol, buuid,
                                                r['_uuid'], puuid)
                                            for puuid, r in ports_original
                                        ])) as g:
                                    for m in g:
                                        yield m
                                # NOTE(review): itertools.chain over a single
                                # generator does not flatten r[0]; verify the
                                # retvalue shape of _get_interface_info
                                add = list(
                                    itertools.chain(
                                        r[0]
                                        for r in self.apiroutine.retvalue))
                                for m in self.apiroutine.waitForSend(
                                        ModuleNotification(
                                            self.getServiceName(),
                                            'update',
                                            datapathid=e.datapathid,
                                            connection=e.connection,
                                            vhost=e.vhost,
                                            add=add,
                                            remove=[],
                                            reason='bridgeup')):
                                    yield m

                            self.apiroutine.subroutine(re_add_interfaces())
                        else:
                            # The ports are removed: drop the port->bridge
                            # mapping of ports that still point to this bridge
                            # (fix: the membership test must be against the
                            # dict, not against the looked-up value)
                            for puuid, _ in ports_original:
                                if puuid in self.ports_uuids and \
                                        self.ports_uuids[puuid] == buuid:
                                    del self.ports_uuids[puuid]
                        del self.managed_ports[(e.vhost, e.datapathid)]
                        self.scheduler.emergesend(
                            ModuleNotification(self.getServiceName(),
                                               'update',
                                               datapathid=e.datapathid,
                                               connection=e.connection,
                                               vhost=e.vhost,
                                               add=[],
                                               remove=ports,
                                               reason='bridgedown'))
        finally:
            # Stop every monitor and announce that the module is no longer
            # synchronized
            for r in list(self.monitor_routines):
                r.close()
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getports(self, datapathid, vhost=''):
        "Return all ports of a specified datapath"
        for m in self._wait_for_sync():
            yield m
        entries = self.managed_ports.get((vhost, datapathid), [])
        self.apiroutine.retvalue = [port for _, port in entries]

    def getallports(self, vhost=None):
        "Return all ``(datapathid, port, vhost)`` tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        result = []
        for (vh, dpid), entries in self.managed_ports.items():
            if vhost is None or vh == vhost:
                result.extend((dpid, port, vh) for _, port in entries)
        self.apiroutine.retvalue = result

    def getportbyno(self, datapathid, portno, vhost=''):
        "Return port with specified portno"
        # Only the low 16 bits identify the port
        portno = portno & 0xffff
        for m in self._wait_for_sync():
            yield m
        result = self._getportbyno(datapathid, portno, vhost)
        self.apiroutine.retvalue = result

    def _getportbyno(self, datapathid, portno, vhost=''):
        # Linear scan over the managed ports of this datapath; None when the
        # datapath or the port number is unknown
        entries = self.managed_ports.get((vhost, datapathid))
        if entries is None:
            return None
        return next((port for _, port in entries if port['ofport'] == portno),
                    None)

    def waitportbyno(self, datapathid, portno, timeout=30, vhost=''):
        "Wait for port with specified portno"
        # Only the low 16 bits identify the port
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m

        def waitinner():
            # Fast path: the port may already be known
            p = self._getportbyno(datapathid, portno, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Reference-count the waiters on this port number so the
                    # update routine knows a notification must be sent
                    self.wait_portnos[(vhost, datapathid, portno)] = \
                            self.wait_portnos.get((vhost, datapathid, portno),0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(
                        None, None, portno, None, vhost, datapathid), )
                except:
                    # Timed out or closed: release our reference
                    v = self.wait_portnos.get((vhost, datapathid, portno))
                    if v is not None:
                        if v <= 1:
                            del self.wait_portnos[(vhost, datapathid, portno)]
                        else:
                            self.wait_portnos[(vhost, datapathid,
                                               portno)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException(
                'Port ' + repr(portno) + ' does not appear before timeout')

    def getportbyname(self, datapathid, name, vhost=''):
        "Return port with specified name"
        # Names are stored as text; accept bytes for convenience
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        for m in self._wait_for_sync():
            yield m
        result = self._getportbyname(datapathid, name, vhost)
        self.apiroutine.retvalue = result

    def _getportbyname(self, datapathid, name, vhost=''):
        # Linear scan over the managed ports of this datapath; None when the
        # datapath or the name is unknown
        entries = self.managed_ports.get((vhost, datapathid))
        if entries is None:
            return None
        return next((port for _, port in entries if port['name'] == name),
                    None)

    def waitportbyname(self, datapathid, name, timeout=30, vhost=''):
        "Wait for port with specified name"
        for m in self._wait_for_sync():
            yield m

        def waitinner():
            # Fast path: the port may already be known
            p = self._getportbyname(datapathid, name, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Reference-count the waiters on this name so the update
                    # routine knows a notification must be sent
                    # (fix: the old count was read from wait_portnos instead
                    # of wait_names, so the counter never accumulated)
                    self.wait_names[(vhost, datapathid, name)] = \
                            self.wait_names.get((vhost, datapathid, name), 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(
                        None, name, None, None, vhost, datapathid), )
                except:
                    # Timed out or closed: release our reference
                    v = self.wait_names.get((vhost, datapathid, name))
                    if v is not None:
                        if v <= 1:
                            del self.wait_names[(vhost, datapathid, name)]
                        else:
                            self.wait_names[(vhost, datapathid, name)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException(
                'Port ' + repr(name) + ' does not appear before timeout')

    def getportbyid(self, id, vhost=''):
        "Return port with the specified id. The return value is a pair: ``(datapath_id, port)``"
        for m in self._wait_for_sync():
            yield m
        # Fix: the result must be stored in apiroutine.retvalue; the previous
        # assignment overwrote self.apiroutine itself, destroying the routine
        # container and never returning a value to the caller.
        self.apiroutine.retvalue = self._getportbyid(id, vhost)

    def _getportbyid(self, id, vhost=''):
        # managed_ids maps (vhost, id) -> list of (datapath_id, port)
        # entries; the first entry wins when present
        entries = self.managed_ids.get((vhost, id))
        return entries[0] if entries else None

    def waitportbyid(self, id, timeout=30, vhost=''):
        "Wait for port with the specified id. The return value is a pair ``(datapath_id, port)``"
        for m in self._wait_for_sync():
            yield m

        def waitinner():
            # Fast path: the id may already be known
            p = self._getportbyid(id, vhost)
            if p is None:
                try:
                    # Reference-count the waiters on this id so the update
                    # routine knows a notification must be sent
                    self.wait_ids[(vhost, id)] = self.wait_ids.get(
                        (vhost, id), 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(
                        None, None, None, id, vhost), )
                except:
                    # Timed out or closed: release our reference
                    v = self.wait_ids.get((vhost, id))
                    if v is not None:
                        if v <= 1:
                            del self.wait_ids[(vhost, id)]
                        else:
                            self.wait_ids[(vhost, id)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = (
                        self.apiroutine.event.datapathid,
                        self.apiroutine.event.port)
            else:
                self.apiroutine.retvalue = p

        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException(
                'Port ' + repr(id) + ' does not appear before timeout')

    def resync(self, datapathid, vhost=''):
        '''
        Resync with current ports
        '''
        # Sometimes when the OVSDB connection is very busy, monitor messages may be dropped.
        # We must deal with this and recover from it
        if (vhost, datapathid) not in self.managed_ports:
            # The datapath is not managed; nothing to resync
            self.apiroutine.retvalue = None
            return
        else:
            for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getconnection',
                             {
                                 'datapathid': datapathid,
                                 'vhost': vhost
                             }):
                yield m
            c = self.apiroutine.retvalue
            if c is not None:
                # For now, we restart the connection...
                for m in c.reconnect(False):
                    yield m
                # Give the reconnect a moment before waiting for the new
                # connection to come up
                for m in self.apiroutine.waitWithTimeout(0.1):
                    yield m
                for m in callAPI(self.apiroutine, 'ovsdbmanager',
                                 'waitconnection', {
                                     'datapathid': datapathid,
                                     'vhost': vhost
                                 }):
                    yield m
        self.apiroutine.retvalue = None
Exemplo n.º 7
0
class OpenflowManager(Module):
    '''
    Manage Openflow Connections
    '''
    service = True
    _default_vhostbind = None

    def __init__(self, server):
        '''
        Create the manager, its main routine and register the public APIs.
        '''
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> [connections]
        self.managed_conns = {}
        # (vhost, endpoint) -> [connections]
        self.endpoint_conns = {}
        self.table_modules = set()
        self._acquiring = False
        self._acquire_updated = False
        self._lastacquire = None
        self._synchronized = False
        routine_apis = [
            api(f, self.apiroutine)
            for f in (self.getconnections, self.getconnection,
                      self.waitconnection, self.getdatapathids,
                      self.getalldatapathids, self.getallconnections,
                      self.getconnectionsbyendpoint,
                      self.getconnectionsbyendpointname, self.getendpoints,
                      self.getallendpoints, self.acquiretable,
                      self.unacquiretable)
        ]
        # lastacquiredtables is a plain (non-routine) API
        self.createAPI(*(routine_apis + [api(self.lastacquiredtables)]))

    def _add_connection(self, conn):
        '''
        Register a new datapath connection, replacing any existing connection
        with the same auxiliary id.

        :return: the list of replaced connections (possibly empty)
        '''
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault(
            (vhost, conn.openflow_datapathid), [])
        removed = []
        for index, existing in enumerate(conns):
            if existing.openflow_auxiliaryid != conn.openflow_auxiliaryid:
                continue
            removed = [existing]
            # Detach the replaced connection from the endpoint index as well
            old_ep = _get_endpoint(existing)
            econns = self.endpoint_conns.get((vhost, old_ep))
            if econns is not None:
                try:
                    econns.remove(existing)
                except ValueError:
                    pass
                if not econns:
                    del self.endpoint_conns[(vhost, old_ep)]
            del conns[index]
            break
        conns.append(conn)
        ep = _get_endpoint(conn)
        self.endpoint_conns.setdefault((vhost, ep), []).append(conn)
        # The main (auxiliaryid == 0) connection gets the default flow setup
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return removed

    def _initialize_connection(self, conn):
        '''
        Clear all flows (and groups, when supported) on a newly connected
        datapath, install the default goto-table flows for the acquired table
        structure, then announce ``FlowInitialize``.
        '''
        ofdef = conn.openflowdef
        # Delete every flow on the datapath
        flow_mod = ofdef.ofp_flow_mod(buffer_id=ofdef.OFP_NO_BUFFER,
                                      out_port=ofdef.OFPP_ANY,
                                      command=ofdef.OFPFC_DELETE)
        # The fields below are only present on newer OpenFlow versions;
        # probe with hasattr before using them
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            # Also delete every group
            group_mod = ofdef.ofp_group_mod(command=ofdef.OFPGC_DELETE,
                                            group_id=ofdef.OFPG_ALL)
            cmds.append(group_mod)
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                # For every table path, chain each table to its successor
                # with a lowest-priority goto-table flow
                cmds = [
                    ofdef.ofp_flow_mod(table_id=t[i][1],
                                       command=ofdef.OFPFC_ADD,
                                       priority=0,
                                       buffer_id=ofdef.OFP_NO_BUFFER,
                                       out_port=ofdef.OFPP_ANY,
                                       out_group=ofdef.OFPG_ANY,
                                       match=ofdef.ofp_match_oxm(),
                                       instructions=[
                                           ofdef.ofp_instruction_goto_table(
                                               table_id=t[i + 1][1])
                                       ]) for _, t in pathtable.items()
                    for i in range(0,
                                   len(t) - 1)
                ]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        for m in self.apiroutine.waitForSend(
                FlowInitialize(conn, conn.openflow_datapathid,
                               conn.protocol.vhost)):
            yield m

    def _acquire_tables(self):
        '''
        Collect table requests from every registered module, compute a global
        table ordering per vhost (topological sort of the ancestor graph) and
        broadcast the result with a ``TableAcquireUpdate`` event.
        '''
        # NOTE(review): `exception` and `vhost_result` are only bound inside
        # the while body; this routine appears to rely on being started only
        # when self._acquire_updated is True -- confirm at the call sites.
        try:
            while self._acquire_updated:
                result = None
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(), )
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll(
                        (callAPI(self.apiroutine, module, 'gettablerequest',
                                 {}) for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests if vhs is not None
                                 for vh in vhs)
                    vhost_result = {}
                    # Requests should be list of (name, (ancester, ancester, ...), pathname)
                    for vh in vhosts:
                        # graph: name -> (set of ancestors, set of descendants)
                        graph = {}
                        table_path = {}
                        try:
                            for r in requests:
                                if r[1] is None or vh in r[1]:
                                    for name, ancesters, pathname in r[0]:
                                        if name in table_path:
                                            if table_path[name] != pathname:
                                                raise ValueError(
                                                    "table conflict detected: %r can not be in two path: %r and %r"
                                                    % (name, table_path[name],
                                                       pathname))
                                        else:
                                            table_path[name] = pathname
                                        if name not in graph:
                                            graph[name] = (set(ancesters),
                                                           set())
                                        else:
                                            graph[name][0].update(ancesters)
                                        for anc in ancesters:
                                            graph.setdefault(
                                                anc,
                                                (set(), set()))[1].add(name)
                        except ValueError as exc:
                            self._logger.error(str(exc))
                            exception = exc
                            break
                        else:
                            # Topological sort of the dependency graph
                            sequences = []

                            def dfs_sort(current):
                                sequences.append(current)
                                for d in graph[current][1]:
                                    anc = graph[d][0]
                                    anc.remove(current)
                                    if not anc:
                                        dfs_sort(d)

                            # Fix: the sort key must use the lambda parameter;
                            # it previously referenced the leaked loop
                            # variable `name`, making the order arbitrary
                            nopre_tables = sorted(
                                [k for k, v in graph.items() if not v[0]],
                                key=lambda x: (table_path.get(x, ''), x))
                            for t in nopre_tables:
                                dfs_sort(t)
                            if len(sequences) < len(graph):
                                rest_tables = set(
                                    graph.keys()).difference(sequences)
                                self._logger.error(
                                    "Circle detected in table acquiring, following tables are related: %r, vhost = %r",
                                    sorted(rest_tables), vh)
                                self._logger.error(
                                    "Circle dependencies are: %s", ", ".join(
                                        repr(tuple(graph[t][0])) + "=>" + t
                                        for t in rest_tables))
                                exception = ValueError(
                                    "Circle detected in table acquiring, following tables are related: %r, vhost = %r"
                                    % (sorted(rest_tables), vh))
                                break
                            elif len(sequences) > 255:
                                self._logger.error(
                                    "Table limit exceeded: %d tables (only 255 allowed), vhost = %r",
                                    len(sequences), vh)
                                exception = ValueError(
                                    "Table limit exceeded: %d tables (only 255 allowed), vhost = %r"
                                    % (len(sequences), vh))
                                break
                            else:
                                # Assign each table its index, then group the
                                # (table, index) pairs by path name
                                full_indices = list(
                                    zip(sequences, itertools.count()))
                                tables = dict(
                                    (k, tuple(g))
                                    for k, g in itertools.groupby(
                                        sorted(full_indices,
                                               key=lambda x: table_path.get(
                                                   x[0], '')),
                                        lambda x: table_path.get(x[0], '')))
                                vhost_result[vh] = (full_indices, tables)
        finally:
            self._acquiring = False
        if exception:
            for m in self.apiroutine.waitForSend(
                    TableAcquireUpdate(exception=exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                self._lastacquire = result
                # The layout changed: re-run the default flow setup on every
                # managed connection
                self._reinitall()
            for m in self.apiroutine.waitForSend(
                    TableAcquireUpdate(result=result)):
                yield m

    def load(self, container):
        '''
        Module load hook: set up the table-acquire delay queue and publish an
        initial (empty) table acquire update before starting.
        '''
        # Put TableAcquireDelayEvent into its own sub-queue so the delayed
        # re-acquire processing is scheduled separately from other events
        self.scheduler.queue.addSubQueue(
            1, TableAcquireDelayEvent.createMatcher(),
            'ofpmanager_tableacquiredelay')
        # Publish an initial TableAcquireUpdate with no result
        for m in container.waitForSend(TableAcquireUpdate(result=None)):
            yield m
        for m in Module.load(self, container):
            yield m

    def unload(self, container, force=False):
        '''
        Module unload hook: remove the sub-queue created in ``load``.
        '''
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(
                syscall_removequeue(self.scheduler.queue,
                                    'ofpmanager_tableacquiredelay')):
            yield m

    def _reinitall(self):
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))

    def _manage_existing(self):
        '''
        Adopt connections that were already established before this module
        started, then mark the module as synchronized.
        '''
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections",
                         {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            # Only manage connections from the bound vhosts (None = all)
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'), )

    def _manage_conns(self):
        '''
        Main routine: track OpenFlow connection setup/down events and keep
        ``managed_conns`` and ``endpoint_conns`` in sync, broadcasting an
        'update' notification for every change.
        '''
        vb = self.vhostbind
        # Adopt already-established connections in parallel
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                # Only watch events from the bound vhosts
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP,
                    _ismatch=lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN,
                    _ismatch=lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    e = self.apiroutine.event
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(
                        ModuleNotification(self.getServiceName(),
                                           'update',
                                           add=[e.connection],
                                           remove=remove))
                else:
                    e = self.apiroutine.event
                    conns = self.managed_conns.get(
                        (e.createby.vhost, e.datapathid))
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            # Already removed, e.g. replaced by a newer
                            # connection in _add_connection
                            pass
                        else:
                            remove.append(e.connection)

                        if not conns:
                            del self.managed_conns[(e.createby.vhost,
                                                    e.datapathid)]
                        # Also delete from endpoint_conns
                        ep = _get_endpoint(e.connection)
                        econns = self.endpoint_conns.get(
                            (e.createby.vhost, ep))
                        if econns is not None:
                            try:
                                econns.remove(e.connection)
                            except ValueError:
                                pass
                            if not econns:
                                del self.endpoint_conns[(e.createby.vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(
                            ModuleNotification(self.getServiceName(),
                                               'update',
                                               add=[],
                                               remove=remove))
        finally:
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getconnections(self, datapathid, vhost=''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(
            self.managed_conns.get((vhost, datapathid), []))

    def getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid,
                                                       vhost)

    def _getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        conns = self.managed_conns.get((vhost, datapathid))
        if conns is None:
            return None
        else:
            for c in conns:
                if c.openflow_auxiliaryid == auxiliaryid:
                    return c
            return None

    def waitconnection(self, datapathid, auxiliaryid=0, timeout=30, vhost=''):
        "Wait for a datapath connection"
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is None:
            # Not connected yet: wait (with timeout) for the setup event of
            # exactly this datapath/auxiliaryid/vhost
            for m in self.apiroutine.waitWithTimeout(
                    timeout,
                    OpenflowConnectionStateEvent.createMatcher(
                        datapathid,
                        auxiliaryid,
                        OpenflowConnectionStateEvent.CONNECTION_SETUP,
                        _ismatch=lambda x: x.createby.vhost == vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException(
                    'Datapath %016x is not connected' % datapathid)
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getdatapathids(self, vhost=''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        # Keys are (vhost, datapathid) pairs; filter on the requested vhost
        self.apiroutine.retvalue = [
            dpid for (vh, dpid) in self.managed_conns.keys() if vh == vhost
        ]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return ``(vhost, datapathid)`` pair."
        for m in self._wait_for_sync():
            yield m
        # Each key is already a (vhost, datapathid) pair
        self.apiroutine.retvalue = [k for k in self.managed_conns]

    def getallconnections(self, vhost=''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        # BUG FIX: ``itertools.chain(iterable_of_lists)`` chains a single
        # iterable, so callers received the per-datapath lists themselves.
        # ``chain.from_iterable`` flattens them into the individual
        # connections, as the docstring promises.
        if vhost is None:
            self.apiroutine.retvalue = list(
                itertools.chain.from_iterable(self.managed_conns.values()))
        else:
            self.apiroutine.retvalue = list(
                itertools.chain.from_iterable(
                    v for k, v in self.managed_conns.items()
                    if k[0] == vhost))

    def getconnectionsbyendpoint(self, endpoint, vhost=''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        key = (vhost, endpoint)
        # May be None when no connection comes from this endpoint
        self.apiroutine.retvalue = self.endpoint_conns.get(key)

    def getconnectionsbyendpointname(self, name, vhost='', timeout=30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        # Resolve the name
        if not name:
            endpoint = ''
            # BUG FIX: this used to call the non-existent
            # ``getconnectionbyendpoint`` (missing the plural 's'),
            # raising AttributeError on the empty-name path.
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM,
                       socket.IPPROTO_TCP,
                       socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname through the asynchronous resolver; a blocking
            # socket.getaddrinfo is not allowed inside the scheduler loop.
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(
                    timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                for r in resp:
                    # Each entry mirrors getaddrinfo: sockaddr is element 4
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # (host, port) for IP endpoints; ignore the port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    # Stop at the first resolved address with live connections
                    if self.apiroutine.retvalue is not None:
                        break

    def getendpoints(self, vhost=''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        # Keys are (vhost, endpoint) pairs; keep endpoints of this vhost
        self.apiroutine.retvalue = [ep for (vh, ep) in self.endpoint_conns
                                    if vh == vhost]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return ``(vhost, endpoint)`` pairs."
        for m in self._wait_for_sync():
            yield m
        # Each key is already a (vhost, endpoint) pair
        self.apiroutine.retvalue = [k for k in self.endpoint_conns]

    def lastacquiredtables(self, vhost=""):
        "Get acquired table IDs"
        # Plain synchronous accessor; returns None for an unknown vhost
        acquired = self._lastacquire
        return acquired.get(vhost)

    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if modulename not in self.table_modules:
            self.table_modules.add(modulename)
            self._acquire_updated = True
            # Start the acquiring subroutine only when one is not already
            # running; it picks up the _acquire_updated flag itself
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this a generator so it can be an API handler
        if False:
            yield

    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.discard(modulename)
            self._acquire_updated = True
            # Re-run acquiring so the freed tables are reassigned
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this a generator so it can be an API handler
        if False:
            yield
# Exemplo n.º 8
# 0
class OpenflowManager(Module):
    '''
    Manage Openflow Connections
    '''
    service = True
    # None = manage connections from every vhost; or a collection of vhost names
    _default_vhostbind = None
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> [connections], auxiliary connections included
        self.managed_conns = {}
        # (vhost, endpoint address) -> [connections]
        self.endpoint_conns = {}
        # Names of modules that have requested table acquisition
        self.table_modules = set()
        # True while an _acquire_tables subroutine is running
        self._acquiring = False
        self._acquire_updated = False
        # Last acquired result: {vhost: (full_index_list, path_table_dict)}
        self._lastacquire = None
        self._synchronized = False
        self.createAPI(api(self.getconnections, self.apiroutine),
                       api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.acquiretable, self.apiroutine),
                       api(self.unacquiretable, self.apiroutine),
                       api(self.lastacquiredtables)
                       )
    def _add_connection(self, conn):
        "Register *conn* in both maps; return the list of replaced (stale) connections."
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault((vhost, conn.openflow_datapathid), [])
        remove = []
        for i in range(0, len(conns)):
            if conns[i].openflow_auxiliaryid == conn.openflow_auxiliaryid:
                # A new connection with the same auxiliary id replaces the old one
                ci = conns[i]
                remove = [ci]
                ep = _get_endpoint(ci)
                econns = self.endpoint_conns.get((vhost, ep))
                if econns is not None:
                    try:
                        econns.remove(ci)
                    except ValueError:
                        pass
                    if not econns:
                        del self.endpoint_conns[(vhost, ep)]
                del conns[i]
                break
        conns.append(conn)
        ep = _get_endpoint(conn)
        econns = self.endpoint_conns.setdefault((vhost, ep), [])
        econns.append(conn)
        # Initialize only the main connection (auxiliary id 0), and only after
        # flow tables have been acquired at least once
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return remove
    def _initialize_connection(self, conn):
        "Clear all flows/groups on *conn* and install the default goto-table flows."
        ofdef = conn.openflowdef
        # Delete every flow on the switch
        flow_mod = ofdef.ofp_flow_mod(buffer_id = ofdef.OFP_NO_BUFFER,
                                                 out_port = ofdef.OFPP_ANY,
                                                 command = ofdef.OFPFC_DELETE
                                                 )
        # These fields do not exist on every OpenFlow version; test before use
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            # Also delete every group
            group_mod = ofdef.ofp_group_mod(command = ofdef.OFPGC_DELETE,
                                            group_id = ofdef.OFPG_ALL
                                            )
            cmds.append(group_mod)
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows: each table on a path falls through to the
            # next table of the same path with priority 0
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                cmds = [ofdef.ofp_flow_mod(table_id = t[i][1],
                                             command = ofdef.OFPFC_ADD,
                                             priority = 0,
                                             buffer_id = ofdef.OFP_NO_BUFFER,
                                             out_port = ofdef.OFPP_ANY,
                                             out_group = ofdef.OFPG_ANY,
                                             match = ofdef.ofp_match_oxm(),
                                             instructions = [ofdef.ofp_instruction_goto_table(table_id = t[i+1][1])]
                                       )
                          for _,t in pathtable.items()
                          for i in range(0, len(t) - 1)]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        # Announce that this datapath is (re)initialized
        for m in self.apiroutine.waitForSend(FlowInitialize(conn, conn.openflow_datapathid, conn.protocol.vhost)):
            yield m
    def _acquire_tables(self):
        "Gather table requests from registered modules and assign table IDs per vhost."
        # ROBUSTNESS: hoisted so the sends after the loop cannot hit an
        # unbound name even if the while body never runs
        exception = None
        vhost_result = {}
        try:
            while self._acquire_updated:
                result = None
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(),)
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll((callAPI(self.apiroutine, module, 'gettablerequest', {}) for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests if vhs is not None for vh in vhs)
                    vhost_result = {}
                    # Requests should be list of (name, (ancester, ancester, ...), pathname)
                    for vh in vhosts:
                        # graph: table name -> (set of ancestors, set of descendants)
                        graph = {}
                        table_path = {}
                        try:
                            for r in requests:
                                if r[1] is None or vh in r[1]:
                                    for name, ancesters, pathname in r[0]:
                                        if name in table_path:
                                            if table_path[name] != pathname:
                                                raise ValueError("table conflict detected: %r can not be in two path: %r and %r" % (name, table_path[name], pathname))
                                        else:
                                            table_path[name] = pathname
                                        if name not in graph:
                                            graph[name] = (set(ancesters), set())
                                        else:
                                            graph[name][0].update(ancesters)
                                        for anc in ancesters:
                                            graph.setdefault(anc, (set(), set()))[1].add(name)
                        except ValueError as exc:
                            self._logger.error(str(exc))
                            exception = exc
                            break
                        else:
                            # Topological sort (DFS) over the dependency graph
                            sequences = []
                            def dfs_sort(current):
                                sequences.append(current)
                                for d in graph[current][1]:
                                    anc = graph[d][0]
                                    anc.remove(current)
                                    if not anc:
                                        dfs_sort(d)
                            # BUG FIX: the sort key previously referenced the
                            # leaked loop variable ``name``, making the key a
                            # constant; keying on the candidate ``x`` gives
                            # root tables a deterministic (path, name) order
                            nopre_tables = sorted([k for k,v in graph.items() if not v[0]], key = lambda x: (table_path.get(x, ''),x))
                            for t in nopre_tables:
                                dfs_sort(t)
                            if len(sequences) < len(graph):
                                # Some tables were never freed of ancestors: a cycle
                                rest_tables = set(graph.keys()).difference(sequences)
                                self._logger.error("Circle detected in table acquiring, following tables are related: %r, vhost = %r", sorted(rest_tables), vh)
                                self._logger.error("Circle dependencies are: %s", ", ".join(repr(tuple(graph[t][0])) + "=>" + t for t in rest_tables))
                                exception = ValueError("Circle detected in table acquiring, following tables are related: %r, vhost = %r" % (sorted(rest_tables),vh))
                                break
                            elif len(sequences) > 255:
                                self._logger.error("Table limit exceeded: %d tables (only 255 allowed), vhost = %r", len(sequences), vh)
                                exception = ValueError("Table limit exceeded: %d tables (only 255 allowed), vhost = %r" % (len(sequences),vh))
                                break
                            else:
                                # Assign table ids by topological position
                                full_indices = list(zip(sequences, itertools.count()))
                                tables = dict((k,tuple(g)) for k,g in itertools.groupby(sorted(full_indices, key = lambda x: table_path.get(x[0], '')),
                                                           lambda x: table_path.get(x[0], '')))
                                vhost_result[vh] = (full_indices, tables)
        finally:
            self._acquiring = False
        if exception:
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(exception = exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                self._lastacquire = result
                # Table layout changed: re-initialize every managed connection
                self._reinitall()
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(result = result)):
                yield m
    def load(self, container):
        "Module load: register the delay sub-queue and send the initial (empty) update."
        self.scheduler.queue.addSubQueue(1, TableAcquireDelayEvent.createMatcher(), 'ofpmanager_tableacquiredelay')
        for m in container.waitForSend(TableAcquireUpdate(result = None)):
            yield m
        for m in Module.load(self, container):
            yield m
    def unload(self, container, force=False):
        "Module unload: remove the delay sub-queue after the base unload."
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(syscall_removequeue(self.scheduler.queue, 'ofpmanager_tableacquiredelay')):
            yield m
    def _reinitall(self):
        # Re-run initialization on every managed connection
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))
    def _manage_existing(self):
        "Pick up connections established before this module loaded, then signal sync."
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            # Respect the vhost binding when one is configured
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m
    def _wait_for_sync(self):
        # Block callers until _manage_existing has finished the initial sync
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)
    def _manage_conns(self):
        "Main routine: track connection setup/down events and keep both maps current."
        vb = self.vhostbind
        # NOTE(review): the second argument presumably controls how the
        # subroutine is started — confirm against RoutineContainer.subroutine
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                # Only accept events from the bound vhosts
                conn_up = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_SETUP,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_DOWN,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    e = self.apiroutine.event
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update', add = [e.connection], remove = remove))
                else:
                    e = self.apiroutine.event
                    conns = self.managed_conns.get((e.createby.vhost, e.datapathid))
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            # Already removed (e.g. replaced by _add_connection)
                            pass
                        else:
                            remove.append(e.connection)
                        if not conns:
                            del self.managed_conns[(e.createby.vhost, e.datapathid)]
                        # Also delete from endpoint_conns
                        ep = _get_endpoint(e.connection)
                        econns = self.endpoint_conns.get((e.createby.vhost, ep))
                        if econns is not None:
                            try:
                                econns.remove(e.connection)
                            except ValueError:
                                pass
                            if not econns:
                                del self.endpoint_conns[(e.createby.vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update', add = [], remove = remove))
        finally:
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))
    def getconnections(self, datapathid, vhost = ''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        # Return a copy so callers cannot mutate our bookkeeping list
        self.apiroutine.retvalue = list(self.managed_conns.get((vhost, datapathid), []))
    def getconnection(self, datapathid, auxiliaryid = 0, vhost = ''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid, vhost)
    def _getconnection(self, datapathid, auxiliaryid = 0, vhost = ''):
        # Synchronous lookup helper; returns None when not found
        conns = self.managed_conns.get((vhost, datapathid))
        if conns is None:
            return None
        else:
            for c in conns:
                if c.openflow_auxiliaryid == auxiliaryid:
                    return c
            return None
    def waitconnection(self, datapathid, auxiliaryid = 0, timeout = 30, vhost = ''):
        "Wait for a datapath connection"
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is None:
            # Wait (bounded by timeout) for a setup event from the same vhost
            for m in self.apiroutine.waitWithTimeout(timeout, 
                            OpenflowConnectionStateEvent.createMatcher(datapathid, auxiliaryid,
                                    OpenflowConnectionStateEvent.CONNECTION_SETUP,
                                    _ismatch = lambda x: x.createby.vhost == vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c
    def getdatapathids(self, vhost = ''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys() if k[0] == vhost]
    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())
    def getallconnections(self, vhost = ''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        # BUG FIX: ``itertools.chain(...)`` was handed a single iterable of
        # lists, so callers received the per-datapath lists instead of the
        # connections; chain.from_iterable flattens them as documented
        if vhost is None:
            self.apiroutine.retvalue = list(itertools.chain.from_iterable(self.managed_conns.values()))
        else:
            self.apiroutine.retvalue = list(itertools.chain.from_iterable(v for k,v in self.managed_conns.items() if k[0] == vhost))
    def getconnectionsbyendpoint(self, endpoint, vhost = ''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        # May be None when no connection comes from this endpoint
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))
    def getconnectionsbyendpointname(self, name, vhost = '', timeout = 30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        # Resolve the name
        if not name:
            endpoint = ''
            # BUG FIX: used to call the non-existent ``getconnectionbyendpoint``
            # (missing the plural 's')
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname through the asynchronous resolver; blocking
            # socket.getaddrinfo is not allowed inside the scheduler loop
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                for r in resp:
                    # Each entry mirrors getaddrinfo: sockaddr is element 4
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # Ignore port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    # Stop at the first resolved address with live connections
                    if self.apiroutine.retvalue is not None:
                        break
    def getendpoints(self, vhost = ''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]
    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())
    def lastacquiredtables(self, vhost = ""):
        "Get acquired table IDs"
        return self._lastacquire.get(vhost)
    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if not modulename in self.table_modules:
            self.table_modules.add(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this a generator so it can be an API handler
        if False:
            yield
    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.remove(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Unreachable yield keeps this a generator so it can be an API handler
        if False:
            yield
# Exemplo n.º 9
# 0
class TestObjectDB(Module):
    '''
    Test module for ObjectDB: creates physical/logical networks and ports in
    the object database, watches the port sets for changes, and exposes
    create/query APIs.
    '''
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._main
        self.routines.append(self.apiroutine)
        # Monotonically increasing counter used to build unique ObjectDB request IDs
        self._reqid = 0
        # Unique owner mark written onto ports managed by this instance
        self._ownerid = uuid1().hex
        self.createAPI(api(self.createlogicalnetwork, self.apiroutine),
                       api(self.createlogicalnetworks, self.apiroutine),
                       api(self.createphysicalnetwork, self.apiroutine),
                       api(self.createphysicalnetworks, self.apiroutine),
                       api(self.createphysicalport, self.apiroutine),
                       api(self.createphysicalports, self.apiroutine),
                       api(self.createlogicalport, self.apiroutine),
                       api(self.createlogicalports, self.apiroutine),
                       api(self.getlogicalnetworks, self.apiroutine))
        self._logger.setLevel(logging.DEBUG)
    def _monitor(self):
        '''Log every DataObjectUpdateEvent, forever.'''
        update_event = DataObjectUpdateEvent.createMatcher()
        while True:
            yield (update_event,)
            self._logger.info('Database update: %r', self.apiroutine.event)
    def _dumpkeys(self, keys):
        '''Fetch the objects for *keys* and leave their dumped (plain-data)
        forms in ``self.apiroutine.retvalue``.'''
        self._reqid += 1
        reqid = ('testobjectdb', self._reqid)
        for m in callAPI(self.apiroutine, 'objectdb', 'mget', {'keys': keys, 'requestid': reqid}):
            yield m
        retobjs = self.apiroutine.retvalue
        # watch_context releases the retrieved references when the block exits
        with watch_context(keys, retobjs, reqid, self.apiroutine):
            self.apiroutine.retvalue = [dump(v) for v in retobjs]
    def _updateport(self, key):
        '''Try to take ownership of the port at *key* (set ``owner``/``status``)
        and log its updates until it is deleted or this module unloads.'''
        unload_matcher = ModuleLoadStateChanged.createMatcher(self.target, ModuleLoadStateChanged.UNLOADING)
        def updateinner():
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            for m in callAPI(self.apiroutine, 'objectdb', 'get', {'key': key, 'requestid': reqid}):
                yield m
            portobj = self.apiroutine.retvalue
            with watch_context([key], [portobj], reqid, self.apiroutine):
                if portobj is not None:
                    @updater
                    def write_status(portobj):
                        # Claim the port only if it is alive and unowned
                        if portobj is None:
                            raise ValueError('Already deleted')
                        if not hasattr(portobj, 'owner'):
                            portobj.owner = self._ownerid
                            portobj.status = 'READY'
                            return [portobj]
                        else:
                            raise ValueError('Already managed')
                    try:
                        for m in callAPI(self.apiroutine, 'objectdb', 'transact', {'keys': [portobj.getkey()], 'updater': write_status}):
                            yield m
                    except ValueError:
                        # Lost the race (or port vanished): someone else manages it
                        pass
                    else:
                        # Wait until our transaction is reflected on the local copy
                        for m in portobj.waitif(self.apiroutine, lambda x: x.isdeleted() or hasattr(x, 'owner')):
                            yield m
                        self._logger.info('Port managed: %r', dump(portobj))
                        while True:
                            for m in portobj.waitif(self.apiroutine, lambda x: True, True):
                                yield m
                            if portobj.isdeleted():
                                self._logger.info('Port deleted: %r', dump(portobj))
                                break
                            else:
                                self._logger.info('Port updated: %r', dump(portobj))
        try:
            # Abort the inner routine when the module starts unloading
            for m in self.apiroutine.withException(updateinner(), unload_matcher):
                yield m
        except RoutineException:
            pass
    def _waitforchange(self, key):
        '''Watch the set object at *key*; spawn _updateport for every port that
        newly appears in the set.'''
        for m in callAPI(self.apiroutine, 'objectdb', 'watch', {'key': key, 'requestid': 'testobjectdb'}):
            yield m
        setobj = self.apiroutine.retvalue
        with watch_context([key], [setobj], 'testobjectdb', self.apiroutine):
            for m in setobj.wait(self.apiroutine):
                yield m
            oldset = set()
            while True:
                # Only process ports added since the last iteration
                # (renamed from `weakref` to avoid shadowing the stdlib module name)
                for wref in setobj.set.dataset().difference(oldset):
                    self.apiroutine.subroutine(self._updateport(wref.getkey()))
                oldset = set(setobj.set.dataset())
                for m in setobj.waitif(self.apiroutine, lambda x: not x.isdeleted(), True):
                    yield m
    def _main(self):
        '''Main routine: run the update monitor and watch both port sets.'''
        routines = []
        routines.append(self._monitor())
        keys = [LogicalPortSet.default_key(), PhysicalPortSet.default_key()]
        for k in keys:
            routines.append(self._waitforchange(k))
        for m in self.apiroutine.executeAll(routines, retnames = ()):
            yield m
    def load(self, container):
        '''On module load, make sure the four top-level set objects exist in
        the database before starting the routines.'''
        @updater
        def initialize(phynetset, lognetset, logportset, phyportset):
            if phynetset is None:
                phynetset = PhysicalNetworkSet()
                phynetset.set = DataObjectSet()
            if lognetset is None:
                lognetset = LogicalNetworkSet()
                lognetset.set = DataObjectSet()
            if logportset is None:
                logportset = LogicalPortSet()
                logportset.set = DataObjectSet()
            if phyportset is None:
                phyportset = PhysicalPortSet()
                phyportset.set = DataObjectSet()
            return [phynetset, lognetset, logportset, phyportset]
        for m in callAPI(container, 'objectdb', 'transact', {'keys':[PhysicalNetworkSet.default_key(),
                                                                   LogicalNetworkSet.default_key(),
                                                                   LogicalPortSet.default_key(),
                                                                   PhysicalPortSet.default_key()],
                                                             'updater': initialize}):
            yield m
        for m in Module.load(self, container):
            yield m
    def createphysicalnetwork(self, type = 'vlan', id = None, **kwargs):
        '''Create a single physical network; return its dumped form.'''
        new_network, new_map = self._createphysicalnetwork(type, id, **kwargs)
        @updater
        def create_phy(physet, phynet, phymap):
            phynet = set_new(phynet, new_network)
            phymap = set_new(phymap, new_map)
            physet.set.dataset().add(phynet.create_weakreference())
            return [physet, phynet, phymap]
        for m in callAPI(self.apiroutine, 'objectdb', 'transact', {'keys':[PhysicalNetworkSet.default_key(),
                                                                           new_network.getkey(),
                                                                           new_map.getkey()],'updater':create_phy}):
            yield m
        for m in self._dumpkeys([new_network.getkey()]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]
    def createphysicalnetworks(self, networks):
        '''Create multiple physical networks; *networks* is a list of dicts of
        keyword arguments for _createphysicalnetwork.'''
        new_networks = [self._createphysicalnetwork(**n) for n in networks]
        @updater
        def create_phys(physet, *phynets):
            # phynets interleaves (network, map) pairs, matching `keys` below
            return_nets = [None, None] * len(new_networks)
            for i in range(0, len(new_networks)):
                return_nets[i * 2] = set_new(phynets[i * 2], new_networks[i][0])
                return_nets[i * 2 + 1] = set_new(phynets[i * 2 + 1], new_networks[i][1])
                physet.set.dataset().add(new_networks[i][0].create_weakreference())
            return [physet] + return_nets
        keys = [sn.getkey() for n in new_networks for sn in n]
        for m in callAPI(self.apiroutine, 'objectdb', 'transact', {'keys':[PhysicalNetworkSet.default_key()] + keys,'updater':create_phys}):
            yield m
        for m in self._dumpkeys([n[0].getkey() for n in new_networks]):
            yield m
    def _createlogicalnetwork(self, physicalnetwork, id = None, **kwargs):
        '''Build (but do not store) a LogicalNetwork and its map object.'''
        if not id:
            id = str(uuid1())
        new_network = LogicalNetwork.create_instance(id)
        for k,v in kwargs.items():
            setattr(new_network, k, v)
        new_network.physicalnetwork = ReferenceObject(PhysicalNetwork.default_key(physicalnetwork))
        new_networkmap = LogicalNetworkMap.create_instance(id)
        new_networkmap.network = new_network.create_reference()
        return new_network,new_networkmap
    def createlogicalnetworks(self, networks):
        '''Create multiple logical networks, allocating VLAN IDs (or the
        single native allocation) on their physical networks.'''
        new_networks = [self._createlogicalnetwork(**n) for n in networks]
        physical_networks = list(set(n[0].physicalnetwork.getkey() for n in new_networks))
        physical_maps = [PhysicalNetworkMap.default_key(PhysicalNetwork._getIndices(k)[1][0]) for k in physical_networks]
        @updater
        def create_logs(logset, *networks):
            # `networks` layout mirrors `keys`: 2N (net, map) pairs, then
            # the physical maps, then the physical networks
            phy_maps = list(networks[len(new_networks) * 2 : len(new_networks) * 2 + len(physical_networks)])
            phy_nets = list(networks[len(new_networks) * 2 + len(physical_networks):])
            phy_dict = dict(zip(physical_networks, zip(phy_nets, phy_maps)))
            return_nets = [None, None] * len(new_networks)
            for i in range(0, len(new_networks)):
                return_nets[2 * i] = set_new(networks[2 * i], new_networks[i][0])
                return_nets[2 * i + 1] = set_new(networks[2 * i + 1], new_networks[i][1])
            for n in return_nets[::2]:
                phynet, phymap = phy_dict.get(n.physicalnetwork.getkey())
                if phynet is None:
                    _, (phyid,) = PhysicalNetwork._getIndices(n.physicalnetwork.getkey())
                    raise ValueError('Physical network %r does not exist' % (phyid,))
                else:
                    if phynet.type == 'vlan':
                        if hasattr(n, 'vlanid'):
                            n.vlanid = int(n.vlanid)
                            if n.vlanid <= 0 or n.vlanid >= 4095:
                                raise ValueError('Invalid VLAN ID')
                            # VLAN id is specified
                            if str(n.vlanid) in phymap.network_allocation:
                                raise ValueError('VLAN ID %r is already allocated in physical network %r' % (n.vlanid,phynet.id))
                            else:
                                for start,end in phynet.vlanrange:
                                    if start <= n.vlanid <= end:
                                        break
                                else:
                                    raise ValueError('VLAN ID %r is not in vlan range of physical network %r' % (n.vlanid,phynet.id))
                            phymap.network_allocation[str(n.vlanid)] = n.create_weakreference()
                        else:
                            # Allocate a new VLAN id
                            for start,end in phynet.vlanrange:
                                for vlanid in range(start, end + 1):
                                    if str(vlanid) not in phymap.network_allocation:
                                        break
                                else:
                                    continue
                                break
                            else:
                                raise ValueError('Not enough VLAN ID to be allocated in physical network %r' % (phynet.id,))
                            n.vlanid = vlanid
                            phymap.network_allocation[str(vlanid)] = n.create_weakreference()
                    else:
                        if phymap.network_allocation:
                            # Fixed: the original passed the tuple as a second
                            # ValueError argument instead of %-formatting it
                            raise ValueError('Physical network %r is already allocated by another logical network' % (phynet.id,))
                        phymap.network_allocation['native'] = n.create_weakreference()
                    phymap.networks.dataset().add(n.create_weakreference())
                logset.set.dataset().add(n.create_weakreference())
            return [logset] + return_nets + phy_maps
        for m in callAPI(self.apiroutine, 'objectdb', 'transact', {'keys': [LogicalNetworkSet.default_key()] +\
                                                                            [sn.getkey() for n in new_networks for sn in n] +\
                                                                            physical_maps +\
                                                                            physical_networks,
                                                                   'updater': create_logs}):
            yield m
        for m in self._dumpkeys([n[0].getkey() for n in new_networks]):
            yield m
    def createlogicalnetwork(self, physicalnetwork, id = None, **kwargs):
        '''Create a single logical network; return its dumped form.'''
        n = {'physicalnetwork':physicalnetwork, 'id':id}
        n.update(kwargs)
        for m in self.createlogicalnetworks([n]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]
    def _createphysicalnetwork(self, type = 'vlan', id = None, **kwargs):
        '''Build (but do not store) a PhysicalNetwork and its map object,
        validating the vlanrange for type 'vlan'.'''
        if not id:
            id = str(uuid1())
        if type == 'vlan':
            if 'vlanrange' not in kwargs:
                raise ValueError(r'Must specify vlanrange with network type="vlan"')
            vlanrange = kwargs['vlanrange']
            # Check that ranges are ordered, non-overlapping and within 1-4094
            try:
                lastend = 0
                for start, end in vlanrange:
                    if start <= lastend:
                        raise ValueError('VLAN sequences overlapped or disordered')
                    lastend = end
                if lastend >= 4095:
                    raise ValueError('VLAN ID out of range')
            except Exception as exc:
                raise ValueError('vlanrange format error: %s' % (str(exc),))
        else:
            # Any other type is treated as a native (untagged) network
            type = 'native'
        new_network = PhysicalNetwork.create_instance(id)
        new_network.type = type
        for k,v in kwargs.items():
            setattr(new_network, k, v)
        new_networkmap = PhysicalNetworkMap.create_instance(id)
        new_networkmap.network = new_network.create_reference()
        return (new_network, new_networkmap)
    def createphysicalport(self, physicalnetwork, name, systemid = '%', bridge = '%', **kwargs):
        '''Create a single physical port; return its dumped form.'''
        p = {'physicalnetwork':physicalnetwork, 'name':name, 'systemid':systemid,'bridge':bridge}
        p.update(kwargs)
        for m in self.createphysicalports([p]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]
    def _createphysicalport(self, physicalnetwork, name, systemid = '%', bridge = '%', **kwargs):
        '''Build (but do not store) a PhysicalPort referencing its network.'''
        new_port = PhysicalPort.create_instance(systemid, bridge, name)
        new_port.physicalnetwork = ReferenceObject(PhysicalNetwork.default_key(physicalnetwork))
        for k,v in kwargs.items():
            setattr(new_port, k, v)
        return new_port
    def createphysicalports(self, ports):
        '''Create multiple physical ports and register each on its physical
        network map and in the global port set.'''
        new_ports = [self._createphysicalport(**p) for p in ports]
        physical_networks = list(set([p.physicalnetwork.getkey() for p in new_ports]))
        physical_maps = [PhysicalNetworkMap.default_key(*PhysicalNetwork._getIndices(k)[1]) for k in physical_networks]
        @updater
        def create_ports(portset, *objs):
            # `objs` layout mirrors `keys`: ports, then maps, then networks
            old_ports = objs[:len(new_ports)]
            phymaps = list(objs[len(new_ports):len(new_ports) + len(physical_networks)])
            phynets = list(objs[len(new_ports) + len(physical_networks):])
            phydict = dict(zip(physical_networks, zip(phynets, phymaps)))
            return_ports = [None] * len(new_ports)
            for i in range(0, len(new_ports)):
                return_ports[i] = set_new(old_ports[i], new_ports[i])
            for p in return_ports:
                phynet, phymap = phydict[p.physicalnetwork.getkey()]
                if phynet is None:
                    _, (phyid,) = PhysicalNetwork._getIndices(p.physicalnetwork.getkey())
                    raise ValueError('Physical network %r does not exist' % (phyid,))
                phymap.ports.dataset().add(p.create_weakreference())
                # Fixed: this add was dedented outside the loop, so only the
                # last port was registered in the global port set
                portset.set.dataset().add(p.create_weakreference())
            return [portset] + return_ports + phymaps
        for m in callAPI(self.apiroutine, 'objectdb', 'transact', {'keys': [PhysicalPortSet.default_key()] +\
                                                                            [p.getkey() for p in new_ports] +\
                                                                            physical_maps +\
                                                                            physical_networks,
                                                                   'updater': create_ports}):
            yield m
        for m in self._dumpkeys([p.getkey() for p in new_ports]):
            yield m
    def createlogicalport(self, logicalnetwork, id = None, **kwargs):
        '''Create a single logical port; return its dumped form.'''
        p = {'logicalnetwork':logicalnetwork, 'id':id}
        p.update(kwargs)
        for m in self.createlogicalports([p]):
            yield m
        self.apiroutine.retvalue = self.apiroutine.retvalue[0]
    def _createlogicalport(self, logicalnetwork, id = None, **kwargs):
        '''Build (but do not store) a LogicalPort referencing its network.'''
        if not id:
            id = str(uuid1())
        new_port = LogicalPort.create_instance(id)
        new_port.logicalnetwork = ReferenceObject(LogicalNetwork.default_key(logicalnetwork))
        for k,v in kwargs.items():
            setattr(new_port, k, v)
        return new_port
    def createlogicalports(self, ports):
        '''Create multiple logical ports and register each on its logical
        network map and in the global port set.'''
        new_ports = [self._createlogicalport(**p) for p in ports]
        logical_networks = list(set([p.logicalnetwork.getkey() for p in new_ports]))
        logical_maps = [LogicalNetworkMap.default_key(*LogicalNetwork._getIndices(k)[1]) for k in logical_networks]
        @updater
        def create_ports(portset, *objs):
            # `objs` layout mirrors `keys`: ports, then maps, then networks
            old_ports = objs[:len(new_ports)]
            logmaps = list(objs[len(new_ports):len(new_ports) + len(logical_networks)])
            lognets = list(objs[len(new_ports) + len(logical_networks):])
            logdict = dict(zip(logical_networks, zip(lognets, logmaps)))
            return_ports = [None] * len(new_ports)
            for i in range(0, len(new_ports)):
                return_ports[i] = set_new(old_ports[i], new_ports[i])
            for p in return_ports:
                lognet, logmap = logdict[p.logicalnetwork.getkey()]
                if lognet is None:
                    _, (logid,) = LogicalNetwork._getIndices(p.logicalnetwork.getkey())
                    raise ValueError('Logical network %r does not exist' % (logid,))
                logmap.ports.dataset().add(p.create_weakreference())
                # Fixed: this add was dedented outside the loop, so only the
                # last port was registered in the global port set
                portset.set.dataset().add(p.create_weakreference())
            return [portset] + return_ports + logmaps
        for m in callAPI(self.apiroutine, 'objectdb', 'transact', {'keys': [LogicalPortSet.default_key()] +\
                                                                            [p.getkey() for p in new_ports] +\
                                                                            logical_maps +\
                                                                            logical_networks,
                                                                   'updater': create_ports}):
            yield m
        for m in self._dumpkeys([p.getkey() for p in new_ports]):
            yield m
    def getlogicalnetworks(self, id = None, physicalnetwork = None, **kwargs):
        '''Query logical networks: by id, by physical network, or all;
        extra keyword arguments are matched against network attributes.'''
        def set_walker(key, obj_set, walk, save):
            # Walk every member of a weak-reference set, saving those whose
            # attributes match all of kwargs
            # (renamed parameter to avoid shadowing the builtin `set`)
            if obj_set is None:
                return
            for o in obj_set.dataset():
                key = o.getkey()
                try:
                    net = walk(key)
                except KeyError:
                    pass
                else:
                    for k,v in kwargs.items():
                        if getattr(net, k, None) != v:
                            break
                    else:
                        save(key)
        def walker_func(set_func):
            # Adapt set_walker to walk the set extracted by set_func(obj)
            def walker(key, obj, walk, save):
                if obj is None:
                    return
                set_walker(key, set_func(obj), walk, save)
            return walker
        if id is not None:
            # Direct lookup by network id
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            for m in callAPI(self.apiroutine, 'objectdb', 'get', {'key' : LogicalNetwork.default_key(id), 'requestid': reqid}):
                yield m
            result = self.apiroutine.retvalue
            with watch_context([LogicalNetwork.default_key(id)], [result], reqid, self.apiroutine):
                if result is None:
                    self.apiroutine.retvalue = []
                    return
                if physicalnetwork is not None and physicalnetwork != result.physicalnetwork.id:
                    self.apiroutine.retvalue = []
                    return
                for k,v in kwargs.items():
                    if getattr(result, k, None) != v:
                        self.apiroutine.retvalue = []
                        return
                self.apiroutine.retvalue = [dump(result)]
        elif physicalnetwork is not None:
            # Walk the networks of the given physical network
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            pm_key = PhysicalNetworkMap.default_key(physicalnetwork)
            for m in callAPI(self.apiroutine, 'objectdb', 'walk', {'keys': [pm_key],
                                                                   'walkerdict': {pm_key: walker_func(lambda x: x.networks)},
                                                                   'requestid': reqid}):
                yield m
            keys, result = self.apiroutine.retvalue
            with watch_context(keys, result, reqid, self.apiroutine):
                self.apiroutine.retvalue = [dump(r) for r in result]
        else:
            # Walk the global logical network set
            self._reqid += 1
            reqid = ('testobjectdb', self._reqid)
            ns_key = LogicalNetworkSet.default_key()
            for m in callAPI(self.apiroutine, 'objectdb', 'walk', {'keys': [ns_key],
                                                                   'walkerdict': {ns_key: walker_func(lambda x: x.set)},
                                                                   'requestid': reqid}):
                yield m
            keys, result = self.apiroutine.retvalue
            with watch_context(keys, result, reqid, self.apiroutine):
                self.apiroutine.retvalue = [dump(r) for r in result]
Exemplo n.º 10
0
class OVSDBManager(Module):
    '''
    Manage OVSDB (JSON-RPC) connections and the bridges behind them
    '''
    service = True
    # None binds to all vhosts; otherwise only connections whose vhost is listed are managed
    _default_vhostbind = None
    # None manages every bridge; otherwise only bridges whose name is listed
    _default_bridgenames = None
    def __init__(self, server):
        '''Create the OVSDB connection manager and register its query APIs.'''
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapath-id) -> JSON-RPC connection
        self.managed_conns = {}
        # (vhost, system-id) -> JSON-RPC connection
        self.managed_systemids = {}
        # connection -> list of (vhost, datapath-id, bridge name, bridge uuid)
        self.managed_bridges = {}
        self.managed_routines = []
        # (vhost, endpoint) -> list of connections
        self.endpoint_conns = {}
        exported = (self.getconnection,
                    self.waitconnection,
                    self.getdatapathids,
                    self.getalldatapathids,
                    self.getallconnections,
                    self.getbridges,
                    self.getbridge,
                    self.getbridgebyuuid,
                    self.waitbridge,
                    self.waitbridgebyuuid,
                    self.getsystemids,
                    self.getallsystemids,
                    self.getconnectionbysystemid,
                    self.waitconnectionbysystemid,
                    self.getconnectionsbyendpoint,
                    self.getconnectionsbyendpointname,
                    self.getendpoints,
                    self.getallendpoints,
                    self.getallbridges,
                    self.getbridgeinfo,
                    self.waitbridgeinfo)
        self.createAPI(*[api(f, self.apiroutine) for f in exported])
        self._synchronized = False
    def _update_bridge(self, connection, protocol, bridge_uuid, vhost):
        '''
        Query name and datapath-id of a single bridge, record it in the
        managed tables, and publish an OVSDBBridgeSetup(UP) event.
        Silently gives up if the JSON-RPC connection fails.
        '''
        try:
            # Wait until datapath_id is non-empty (timeout 5000 — presumably
            # milliseconds per the OVSDB transact "wait" operation; confirm),
            # then select datapath_id and name in the same transaction
            method, params = ovsdb.transact('Open_vSwitch',
                                            ovsdb.wait('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                        ["datapath_id"], [{"datapath_id": ovsdb.oset()}], False, 5000),
                                            ovsdb.select('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                                         ["datapath_id","name"]))
            for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                yield m
            # First result corresponds to the wait operation
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            # Second result corresponds to the select operation
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            if r['rows']:
                r0 = r['rows'][0]
                name = r0['name']
                # datapath_id is reported as a hexadecimal string
                dpid = int(r0['datapath_id'], 16)
                # Honor the bridgenames filter: None means manage all bridges
                if self.bridgenames is None or name in self.bridgenames:
                    self.managed_bridges[connection].append((vhost, dpid, name, bridge_uuid))
                    self.managed_conns[(vhost, dpid)] = connection
                    for m in self.apiroutine.waitForSend(OVSDBBridgeSetup(OVSDBBridgeSetup.UP,
                                                               dpid,
                                                               connection.ovsdb_systemid,
                                                               name,
                                                               connection,
                                                               connection.connmark,
                                                               vhost,
                                                               bridge_uuid)):
                        yield m
        except JsonRPCProtocolException:
            # Connection-level failure: the caller's monitor loop will
            # handle the disconnect; nothing to clean up here
            pass
    def _get_bridges(self, connection, protocol):
        '''
        Register *connection* as a managed OVSDB connection: read its
        system-id, set up a monitor on the Bridge table, process the initial
        bridges, then keep handling bridge add/delete notifications until the
        connection goes down.
        '''
        try:
            try:
                vhost = protocol.vhost
                if not hasattr(connection, 'ovsdb_systemid'):
                    # Read system-id from external_ids of the Open_vSwitch table
                    method, params = ovsdb.transact('Open_vSwitch', ovsdb.select('Open_vSwitch', [], ['external_ids']))
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    result = self.apiroutine.jsonrpc_result[0]
                    system_id = ovsdb.omap_getvalue(result['rows'][0]['external_ids'], 'system-id')
                    connection.ovsdb_systemid = system_id
                else:
                    # Reuse the system-id cached on a reconnected connection
                    system_id = connection.ovsdb_systemid
                if (vhost, system_id) in self.managed_systemids:
                    # An older connection with the same system-id exists:
                    # remove it from the endpoint index before replacing it
                    oc = self.managed_systemids[(vhost, system_id)]
                    ep = _get_endpoint(oc)
                    econns = self.endpoint_conns.get((vhost, ep))
                    if econns:
                        try:
                            econns.remove(oc)
                        except ValueError:
                            pass
                    del self.managed_systemids[(vhost, system_id)]
                self.managed_systemids[(vhost, system_id)] = connection
                self.managed_bridges[connection] = []
                ep = _get_endpoint(connection)
                self.endpoint_conns.setdefault((vhost, ep), []).append(connection)
                # Monitor name and datapath_id of all bridges
                method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_manager_bridges_monitor', {'Bridge':ovsdb.monitor_request(['name', 'datapath_id'])})
                for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                    yield m
                if 'error' in self.apiroutine.jsonrpc_result:
                    # The monitor is already set, cancel it first
                    method, params = ovsdb.monitor_cancel('ovsdb_manager_bridges_monitor')
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine, False):
                        yield m
                    method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_manager_bridges_monitor', {'Bridge':ovsdb.monitor_request(['name', 'datapath_id'])})
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    if 'error' in self.apiroutine.jsonrpc_result:
                        raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result))
            except Exception:
                # Always publish the setup event so that waiters (e.g.
                # _manage_existing) are not blocked, then re-raise
                for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                    yield m
                raise
            else:
                # Process initial bridges
                init_subprocesses = [self._update_bridge(connection, protocol, buuid, vhost)
                                    for buuid in self.apiroutine.jsonrpc_result['Bridge'].keys()]
                def init_process():
                    # Publish OVSDBConnectionSetup once every initial bridge
                    # has been processed (or on failure)
                    try:
                        with closing(self.apiroutine.executeAll(init_subprocesses, retnames = ())) as g:
                            for m in g:
                                yield m
                    except Exception:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                        raise
                    else:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                self.apiroutine.subroutine(init_process())
            # Wait for notify
            notification = JsonRPCNotificationEvent.createMatcher('update', connection, connection.connmark, _ismatch = lambda x: x.params[0] == 'ovsdb_manager_bridges_monitor')
            conn_down = protocol.statematcher(connection)
            while True:
                yield (conn_down, notification)
                if self.apiroutine.matcher is conn_down:
                    # Connection lost: stop processing notifications
                    break
                else:
                    for buuid, v in self.apiroutine.event.params[1]['Bridge'].items():
                        # If a bridge's name or datapath-id is changed, we remove this bridge and add it again
                        if 'old' in v:
                            # A bridge is deleted
                            bridges = self.managed_bridges[connection]
                            for i in range(0, len(bridges)):
                                # bridges[i] is (vhost, dpid, name, bridge_uuid)
                                if buuid == bridges[i][3]:
                                    self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                               bridges[i][1],
                                                                               system_id,
                                                                               bridges[i][2],
                                                                               connection,
                                                                               connection.connmark,
                                                                               vhost,
                                                                               bridges[i][3],
                                                                               new_datapath_id =
                                                                                int(v['new']['datapath_id'], 16) if 'new' in v and 'datapath_id' in v['new']
                                                                                else None))
                                    del self.managed_conns[(vhost, bridges[i][1])]
                                    del bridges[i]
                                    break
                        if 'new' in v:
                            # A bridge is added
                            self.apiroutine.subroutine(self._update_bridge(connection, protocol, buuid, vhost))
        except JsonRPCProtocolException:
            pass
        finally:
            # Allow a future reconnect to start this routine again
            del connection._ovsdb_manager_get_bridges
    def _manage_existing(self):
        """
        Pick up JSON-RPC connections that were established before this module
        started, attach a bridge-monitoring subroutine to each, and announce
        'synchronized' once all of them finish their initial setup.
        """
        # Ask the JSON-RPC server module for all currently active connections
        for m in callAPI(self.apiroutine, "jsonrpcserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        conns = self.apiroutine.retvalue
        for c in conns:
            if vb is None or c.protocol.vhost in vb:
                # _ovsdb_manager_get_bridges marks connections that already have a
                # monitoring subroutine, so we never start two for one connection
                if not hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges = self.apiroutine.subroutine(self._get_bridges(c, c.protocol))
        matchers = [OVSDBConnectionSetup.createMatcher(None, c, c.connmark) for c in conns]
        # Wait until every connection has produced its setup event before
        # declaring the module synchronized
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m
    def _wait_for_sync(self):
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)    
    def _manage_conns(self):
        """
        Main routine of the module: watch JSON-RPC connection up/down events and
        keep the bridge/connection bookkeeping tables consistent, broadcasting
        OVSDBBridgeSetup DOWN events when connections (or the module) go down.
        """
        try:
            # Handle connections that already exist in a separate subroutine
            self.apiroutine.subroutine(self._manage_existing())
            vb = self.vhostbind
            if vb is not None:
                # Restrict matching to connections from the bound vhosts
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
            else:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    # New connection: attach a monitor subroutine unless one is
                    # already running on this connection
                    if not hasattr(self.apiroutine.event.connection, '_ovsdb_manager_get_bridges'):
                        self.apiroutine.event.connection._ovsdb_manager_get_bridges = self.apiroutine.subroutine(self._get_bridges(self.apiroutine.event.connection, self.apiroutine.event.createby))
                else:
                    # Connection lost: drop its entries and broadcast DOWN for
                    # every bridge that was managed through it
                    e = self.apiroutine.event
                    conn = e.connection
                    bridges = self.managed_bridges.get(conn)
                    if bridges is not None:
                        del self.managed_systemids[(e.createby.vhost, conn.ovsdb_systemid)]
                        del self.managed_bridges[conn]
                        for vhost, dpid, name, buuid in bridges:
                            del self.managed_conns[(vhost, dpid)]
                            self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                       dpid,
                                                                       conn.ovsdb_systemid,
                                                                       name,
                                                                       conn,
                                                                       conn.connmark,
                                                                       e.createby.vhost,
                                                                       buuid))
                        econns = self.endpoint_conns.get(_get_endpoint(conn))
                        if econns is not None:
                            try:
                                econns.remove(conn)
                            except ValueError:
                                # Already removed elsewhere; ignore
                                pass
        finally:
            # Module is stopping: close all monitor subroutines and notify
            # listeners that every managed bridge is going down
            for c in self.managed_bridges.keys():
                if hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges.close()
                bridges = self.managed_bridges.get(c)
                if bridges is not None:
                    for vhost, dpid, name, buuid in bridges:
                        del self.managed_conns[(vhost, dpid)]
                        self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                   dpid, 
                                                                   c.ovsdb_systemid, 
                                                                   name, 
                                                                   c, 
                                                                   c.connmark, 
                                                                   c.protocol.vhost,
                                                                   buuid))
    def getconnection(self, datapathid, vhost = ''):
        "Get current connection of datapath"
        for event in self._wait_for_sync():
            yield event
        key = (vhost, datapathid)
        self.apiroutine.retvalue = self.managed_conns.get(key)
    def waitconnection(self, datapathid, timeout = 30, vhost = ''):
        """Return (via apiroutine.retvalue) the connection for the datapath,
        waiting up to *timeout* seconds for it to connect when absent."""
        for m in self.getconnection(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is not None:
            # Already connected
            return
        # Wait for the bridge-up notification for this datapath
        bridge_up = OVSDBBridgeSetup.createMatcher(
                            state = OVSDBBridgeSetup.UP,
                            datapathid = datapathid, vhost = vhost)
        for m in self.apiroutine.waitWithTimeout(timeout, bridge_up):
            yield m
        if self.apiroutine.timeout:
            raise ConnectionResetException('Datapath is not connected')
        self.apiroutine.retvalue = self.apiroutine.event.connection
    def getdatapathids(self, vhost = ''):
        "Get All datapath IDs"
        for event in self._wait_for_sync():
            yield event
        self.apiroutine.retvalue = [dpid for (vh, dpid) in self.managed_conns.keys() if vh == vhost]
    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for event in self._wait_for_sync():
            yield event
        self.apiroutine.retvalue = [key for key in self.managed_conns]
    def getallconnections(self, vhost = ''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for event in self._wait_for_sync():
            yield event
        if vhost is None:
            result = list(self.managed_bridges.keys())
        else:
            result = [conn for conn in self.managed_bridges.keys() if conn.protocol.vhost == vhost]
        self.apiroutine.retvalue = result
    def getbridges(self, connection):
        "Get all (dpid, name, _uuid) tuple on this connection"
        for event in self._wait_for_sync():
            yield event
        bridges = self.managed_bridges.get(connection)
        if bridges is None:
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = [(dpid, name, buuid) for _, dpid, name, buuid in bridges]
    def getallbridges(self, vhost = None):
        "Get all (dpid, name, _uuid) tuple for all connections, optionally filtered by vhost"
        for event in self._wait_for_sync():
            yield event
        result = []
        for conn, bridges in self.managed_bridges.items():
            if vhost is None or conn.protocol.vhost == vhost:
                result.extend((dpid, name, buuid) for _, dpid, name, buuid in bridges)
        self.apiroutine.retvalue = result
    def getbridge(self, connection, name):
        "Get datapath ID on this connection with specified name"
        for event in self._wait_for_sync():
            yield event
        result = None
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, bridge_name, _ in bridges:
                if bridge_name == name:
                    result = dpid
                    break
        self.apiroutine.retvalue = result
    def waitbridge(self, connection, name, timeout = 30):
        "Wait for bridge with specified name appears and return the datapath-id"
        bnames = self.bridgenames
        if bnames is not None and name not in bnames:
            # Names outside the configured filter will never be monitored
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear: it is not in the selected bridge names')
        for m in self.getbridge(connection, name):
            yield m
        if self.apiroutine.retvalue is not None:
            return
        bridge_setup = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.UP,
                                                      None,
                                                      None,
                                                      name,
                                                      connection
                                                      )
        conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                              connection,
                                                              connection.connmark)
        for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear')
        if self.apiroutine.matcher is conn_down:
            raise ConnectionResetException('Connection is down before bridge ' + repr(name) + ' appears')
        self.apiroutine.retvalue = self.apiroutine.event.datapathid
    def getbridgebyuuid(self, connection, uuid):
        "Get datapath ID of bridge on this connection with specified _uuid"
        for event in self._wait_for_sync():
            yield event
        result = None
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, _, bridge_uuid in bridges:
                if bridge_uuid == uuid:
                    result = dpid
                    break
        self.apiroutine.retvalue = result
    def waitbridgebyuuid(self, connection, uuid, timeout = 30):
        "Wait for bridge with specified _uuid appears and return the datapath-id"
        for m in self.getbridgebyuuid(connection, uuid):
            yield m
        if self.apiroutine.retvalue is not None:
            return
        bridge_setup = OVSDBBridgeSetup.createMatcher(state = OVSDBBridgeSetup.UP,
                                                     connection = connection,
                                                     bridgeuuid = uuid
                                                     )
        conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                              connection,
                                                              connection.connmark)
        for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(uuid) + ' does not appear')
        if self.apiroutine.matcher is conn_down:
            raise ConnectionResetException('Connection is down before bridge ' + repr(uuid) + ' appears')
        self.apiroutine.retvalue = self.apiroutine.event.datapathid
    def getsystemids(self, vhost = ''):
        "Get All system-ids"
        for event in self._wait_for_sync():
            yield event
        self.apiroutine.retvalue = [sysid for (vh, sysid) in self.managed_systemids.keys() if vh == vhost]
    def getallsystemids(self):
        "Get all system-ids from any vhost. Return (vhost, system-id) pair."
        for event in self._wait_for_sync():
            yield event
        self.apiroutine.retvalue = [key for key in self.managed_systemids]
    def getconnectionbysystemid(self, systemid, vhost = ''):
        "Get the current connection for the specified system-id, or None"
        for event in self._wait_for_sync():
            yield event
        key = (vhost, systemid)
        self.apiroutine.retvalue = self.managed_systemids.get(key)
    def waitconnectionbysystemid(self, systemid, timeout = 30, vhost = ''):
        "Wait for a connection with specified system-id"
        for m in self.getconnectionbysystemid(systemid, vhost):
            yield m
        if self.apiroutine.retvalue is not None:
            # Connection already present
            return
        for m in self.apiroutine.waitWithTimeout(timeout,
                        OVSDBConnectionSetup.createMatcher(
                                systemid, None, None, vhost)):
            yield m
        if self.apiroutine.timeout:
            raise ConnectionResetException('Datapath is not connected')
        self.apiroutine.retvalue = self.apiroutine.event.connection
    def getconnectionsbyendpoint(self, endpoint, vhost = ''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for event in self._wait_for_sync():
            yield event
        key = (vhost, endpoint)
        self.apiroutine.retvalue = self.endpoint_conns.get(key)
    def getconnectionsbyendpointname(self, name, vhost = '', timeout = 30):
        """
        Get connections by endpoint name (domain name, IP or IPv6 address).

        The name is resolved through the asynchronous resolver; each resolved
        address is then looked up with getconnectionsbyendpoint until a match is
        found. apiroutine.retvalue is set to the matching connection list, or
        None when nothing matches.

        :param name: hostname/IP to resolve; empty string means the '' endpoint
        :param vhost: vhost to search in
        :param timeout: seconds to wait for name resolution
        :raises IOError: when the name cannot be resolved, or resolution times out
        """
        if not name:
            # Fix: this previously called the non-existent
            # self.getconnectionbyendpoint and raised AttributeError
            for m in self.getconnectionsbyendpoint('', vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve is only allowed through the asynchronous resolver;
            # a synchronous getaddrinfo would block the whole scheduler
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                raise IOError('Resolve hostname timeout: ' + name)
            if hasattr(self.apiroutine.event, 'error'):
                raise IOError('Cannot resolve hostname: ' + name)
            resp = self.apiroutine.event.response
            for r in resp:
                raddr = r[4]
                if isinstance(raddr, tuple):
                    # (host, port, ...) tuple from getaddrinfo; ignore the port
                    endpoint = raddr[0]
                else:
                    # Unix socket? This should not happen, but in case...
                    endpoint = raddr
                for m in self.getconnectionsbyendpoint(endpoint, vhost):
                    yield m
                if self.apiroutine.retvalue is not None:
                    break
    def getendpoints(self, vhost = ''):
        "Get all endpoints for vhost"
        for event in self._wait_for_sync():
            yield event
        self.apiroutine.retvalue = [endpoint for (vh, endpoint) in self.endpoint_conns if vh == vhost]
    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for event in self._wait_for_sync():
            yield event
        self.apiroutine.retvalue = [key for key in self.endpoint_conns]
    def getbridgeinfo(self, datapathid, vhost = ''):
        "Get (bridgename, systemid, bridge_uuid) tuple from bridge datapathid"
        for m in self.getconnection(datapathid, vhost):
            yield m
        conn = self.apiroutine.retvalue
        if conn is None:
            # No connection for this datapath; retvalue stays None
            return
        result = None
        bridges = self.managed_bridges.get(conn)
        if bridges is not None:
            for _, dpid, bridge_name, bridge_uuid in bridges:
                if dpid == datapathid:
                    result = (bridge_name, conn.ovsdb_systemid, bridge_uuid)
                    break
        self.apiroutine.retvalue = result
    def waitbridgeinfo(self, datapathid, timeout = 30, vhost = ''):
        "Wait for bridge with datapathid, and return (bridgename, systemid, bridge_uuid) tuple"
        for m in self.getbridgeinfo(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is not None:
            return
        for m in self.apiroutine.waitWithTimeout(timeout,
                    OVSDBBridgeSetup.createMatcher(
                                OVSDBBridgeSetup.UP, datapathid,
                                None, None, None, None,
                                vhost)):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBBridgeNotAppearException('Bridge 0x%016x does not appear before timeout' % (datapathid,))
        e = self.apiroutine.event
        self.apiroutine.retvalue = (e.name, e.systemid, e.bridgeuuid)
Example No. 11
0
class OVSDBPortManager(Module):
    '''
    Manage Ports from OVSDB Protocol
    '''
    service = True
    def __init__(self, server):
        """
        Create the port manager module and register its query APIs.

        :param server: the VLCP server object this module is attached to
        """
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # (vhost, datapath_id) -> [(port_uuid, interface_info_dict), ...]
        self.managed_ports = {}
        # (vhost, iface-id) -> [(datapath_id, interface_info_dict), ...]
        self.managed_ids = {}
        # Subroutines currently monitoring OVSDB connections
        self.monitor_routines = set()
        # port _uuid -> bridge _uuid
        self.ports_uuids = {}
        # Pending waiters keyed by (vhost, datapath_id, ofport)
        self.wait_portnos = {}
        # Pending waiters keyed by (vhost, datapath_id, port name)
        self.wait_names = {}
        # Pending waiters keyed by (vhost, iface-id)
        self.wait_ids = {}
        # bridge _uuid -> datapath_id
        self.bridge_datapathid = {}
        self.createAPI(api(self.getports, self.apiroutine),
                       api(self.getallports, self.apiroutine),
                       api(self.getportbyid, self.apiroutine),
                       api(self.waitportbyid, self.apiroutine),
                       api(self.getportbyname, self.apiroutine),
                       api(self.waitportbyname, self.apiroutine),
                       api(self.getportbyno, self.apiroutine),
                       api(self.waitportbyno, self.apiroutine),
                       api(self.resync, self.apiroutine)
                       )
        # Set to True once the initial port inventory has been collected
        self._synchronized = False
    def _get_interface_info(self, connection, protocol, buuid, interface_uuid, port_uuid):
        """
        Query OVSDB for one Interface row and register it into the managed tables.

        apiroutine.retvalue is set to [interface_info] on success, or [] when the
        interface cannot be used (missing row, negative ofport, untracked bridge,
        or a broken connection).

        :param buuid: _uuid of the bridge the interface's port belongs to
        :param interface_uuid: _uuid of the Interface row to query
        :param port_uuid: _uuid of the Port row containing the interface
        """
        try:
            # The wait() operations make the transaction fail fast (5000ms)
            # until ovs-vswitchd has filled in ofport and ifindex, so we never
            # read a half-initialized row
            method, params = ovsdb.transact('Open_vSwitch',
                                            ovsdb.wait('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]],
                                                       ["ofport"], [{"ofport":ovsdb.oset()}], False, 5000),
                                            ovsdb.wait('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]],
                                                       ["ifindex"], [{"ifindex":ovsdb.oset()}], False, 5000),
                                            ovsdb.select('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]],
                                                                         ["_uuid", "name", "ifindex", "ofport", "type", "external_ids"]))
            for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                yield m
            # Check each sub-operation result for errors
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[2]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error']))
            if not r['rows']:
                # Interface row disappeared before we could read it
                self.apiroutine.retvalue = []
                return
            r0 = r['rows'][0]
            if r0['ofport'] < 0:
                # Ignore this port because it is in an error state
                self.apiroutine.retvalue = []
                return
            # Normalize OVSDB wire formats to plain Python values
            r0['_uuid'] = r0['_uuid'][1]
            r0['ifindex'] = ovsdb.getoptional(r0['ifindex'])
            r0['external_ids'] = ovsdb.getdict(r0['external_ids'])
            if buuid not in self.bridge_datapathid:
                # The bridge is no longer tracked; drop the interface
                self.apiroutine.retvalue = []
                return
            else:
                datapath_id = self.bridge_datapathid[buuid]
            if 'iface-id' in r0['external_ids']:
                # Also index the interface by its external iface-id
                eid = r0['external_ids']['iface-id']
                r0['id'] = eid
                id_ports = self.managed_ids.setdefault((protocol.vhost, eid), [])
                id_ports.append((datapath_id, r0))
            else:
                r0['id'] = None
            self.managed_ports.setdefault((protocol.vhost, datapath_id),[]).append((port_uuid, r0))
            # Wake up any waiters registered for this port number / name / id
            notify = False
            if (protocol.vhost, datapath_id, r0['ofport']) in self.wait_portnos:
                notify = True
                del self.wait_portnos[(protocol.vhost, datapath_id, r0['ofport'])]
            if (protocol.vhost, datapath_id, r0['name']) in self.wait_names:
                notify = True
                del self.wait_names[(protocol.vhost, datapath_id, r0['name'])]
            if (protocol.vhost, r0['id']) in self.wait_ids:
                notify = True
                del self.wait_ids[(protocol.vhost, r0['id'])]
            if notify:
                for m in self.apiroutine.waitForSend(OVSDBPortUpNotification(connection, r0['name'],
                                                                             r0['ofport'], r0['id'],
                                                                             protocol.vhost, datapath_id,
                                                                             port = r0)):
                    yield m
            self.apiroutine.retvalue = [r0]
        except JsonRPCProtocolException:
            # Connection broken while querying; report no interface
            self.apiroutine.retvalue = []
    def _remove_interface_id(self, connection, protocol, datapath_id, port):
        eid = port['id']
        eid_list = self.managed_ids.get((protocol.vhost, eid))
        for i in range(0, len(eid_list)):
            if eid_list[i][1]['_uuid'] == port['_uuid']:
                del eid_list[i]
                break
    def _remove_interface(self, connection, protocol, datapath_id, interface_uuid, port_uuid):
        ports = self.managed_ports.get((protocol.vhost, datapath_id))
        r = None
        if ports is not None:
            for i in range(0, len(ports)):
                if ports[i][1]['_uuid'] == interface_uuid:
                    r = ports[i][1]
                    if r['id']:
                        self._remove_interface_id(connection, protocol, datapath_id, r)
                    del ports[i]
                    break
            if not ports:
                del self.managed_ports[(protocol.vhost, datapath_id)]
        return r
    def _remove_all_interface(self, connection, protocol, datapath_id, port_uuid, buuid):
        ports = self.managed_ports.get((protocol.vhost, datapath_id))
        if ports is not None:
            removed_ports = [r for puuid, r in ports if puuid == port_uuid]
            not_removed_ports = [(puuid, r) for puuid, r in ports if puuid != port_uuid]
            ports[:len(not_removed_ports)] = not_removed_ports
            del ports[len(not_removed_ports):]
            for r in removed_ports:
                if r['id']:
                    self._remove_interface_id(connection, protocol, datapath_id, r)
            if not ports:
                del self.managed_ports[(protocol.vhost, datapath_id)]
            return removed_ports
        if port_uuid in self.ports_uuids and self.ports_uuids[port_uuid] == buuid:
            del self.ports_uuids[port_uuid]
        return []
    def _update_interfaces(self, connection, protocol, updateinfo, update = True):
        """
        There are several kinds of updates, they may appear together:
        
        1. New bridge created (or from initial updateinfo). We should add all the interfaces to the list.
        
        2. Bridge removed. Remove all the ports.
        
        3. name and datapath_id may be changed. We will consider this as a new bridge created, and an old
           bridge removed.
        
        4. Bridge ports modification, i.e. add/remove ports.
           a) Normally a port record is created/deleted together. A port record cannot exist without a
              bridge containing it.
               
           b) It is also possible that a port is removed from one bridge and added to another bridge, in
              this case the ports do not appear in the updateinfo
            
        5. Port interfaces modification, i.e. add/remove interfaces. The bridge record may not appear in this
           situation.
           
        We must consider these situations carefully and process them in correct order.
        """
        port_update = updateinfo.get('Port', {})
        bridge_update = updateinfo.get('Bridge', {})
        working_routines = []
        def process_bridge(buuid, uo):
            try:
                nv = uo['new']
                if 'datapath_id' in nv:
                    datapath_id = int(nv['datapath_id'], 16)
                    self.bridge_datapathid[buuid] = datapath_id
                elif buuid in self.bridge_datapathid:
                    datapath_id = self.bridge_datapathid[buuid]
                else:
                    # This should not happen, but just in case...
                    for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitbridge', {'connection': connection,
                                                                    'name': nv['name'],
                                                                    'timeout': 5}):
                        yield m
                    datapath_id = self.apiroutine.retvalue
                if 'ports' in nv:
                    nset = set((p for _,p in ovsdb.getlist(nv['ports'])))
                else:
                    nset = set()
                if 'old' in uo:
                    ov = uo['old']
                    if 'ports' in ov:
                        oset = set((p for _,p in ovsdb.getlist(ov['ports'])))
                    else:
                        # new ports are not really added; it is only sent because datapath_id is modified
                        nset = set()
                        oset = set()
                    if 'datapath_id' in ov:
                        old_datapathid = int(ov['datapath_id'], 16)
                    else:
                        old_datapathid = datapath_id
                else:
                    oset = set()
                    old_datapathid = datapath_id
                # For every deleted port, remove the interfaces with this port _uuid
                remove = []
                add_routine = []
                for puuid in oset - nset:
                    remove += self._remove_all_interface(connection, protocol, old_datapathid, puuid, buuid)
                # For every port not changed, check if the interfaces are modified;
                for puuid in oset.intersection(nset):
                    if puuid in port_update:
                        # The port is modified, there should be an 'old' set and 'new' set
                        pu = port_update[puuid]
                        if 'old' in pu:
                            poset = set((p for _,p in ovsdb.getlist(pu['old']['interfaces'])))
                        else:
                            poset = set()
                        if 'new' in pu:
                            pnset = set((p for _,p in ovsdb.getlist(pu['new']['interfaces'])))
                        else:
                            pnset = set()
                        # Remove old interfaces
                        remove += [r for r in 
                                   (self._remove_interface(connection, protocol, datapath_id, iuuid, puuid)
                                    for iuuid in (poset - pnset)) if r is not None]
                        # Prepare to add new interfaces
                        add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                        for iuuid in (pnset - poset)]
                # For every port added, add the interfaces
                def add_port_interfaces(puuid):
                    # If the uuid does not appear in update info, we have no choice but to query interfaces with select
                    # we cannot use data from other bridges; the port may be moved from a bridge which is not tracked
                    try:
                        method, params = ovsdb.transact('Open_vSwitch', ovsdb.select('Port',
                                                                                     [["_uuid", "==", ovsdb.uuid(puuid)]],
                                                                                     ["interfaces"]))
                        for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                            yield m
                        r = self.apiroutine.jsonrpc_result[0]
                        if 'error' in r:
                            raise JsonRPCErrorResultException('Error when query interfaces from port ' + repr(puuid) + ': ' + r['error'])
                        if r['rows']:
                            interfaces = ovsdb.getlist(r['rows'][0]['interfaces'])
                            with closing(self.apiroutine.executeAll([self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                                                 for _,iuuid in interfaces])) as g:
                                for m in g:
                                    yield m
                            self.apiroutine.retvalue = list(itertools.chain(r[0] for r in self.apiroutine.retvalue))
                        else:
                            self.apiroutine.retvalue = []
                    except JsonRPCProtocolException:
                        self.apiroutine.retvalue = []
                    except ConnectionResetException:
                        self.apiroutine.retvalue = []
                
                for puuid in nset - oset:
                    self.ports_uuids[puuid] = buuid
                    if puuid in port_update and 'new' in port_update[puuid] \
                            and 'old' not in port_update[puuid]:
                        # Add all the interfaces in 'new'
                        interfaces = ovsdb.getlist(port_update[puuid]['new']['interfaces'])
                        add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                        for _,iuuid in interfaces]
                    else:
                        add_routine.append(add_port_interfaces(puuid))
                # Execute the add_routine
                try:
                    with closing(self.apiroutine.executeAll(add_routine)) as g:
                        for m in g:
                            yield m
                except:
                    add = []
                    raise
                else:
                    add = list(itertools.chain(r[0] for r in self.apiroutine.retvalue))
                finally:
                    if update:
                        self.scheduler.emergesend(
                                ModuleNotification(self.getServiceName(), 'update',
                                                   datapathid = datapath_id,
                                                   connection = connection,
                                                   vhost = protocol.vhost,
                                                   add = add, remove = remove,
                                                   reason = 'bridgemodify'
                                                            if 'old' in uo
                                                            else 'bridgeup'
                                                   ))
            except JsonRPCProtocolException:
                pass
            except ConnectionResetException:
                pass
            except OVSDBBridgeNotAppearException:
                pass
        ignore_ports = set()
        for buuid, uo in bridge_update.items():
            # Bridge removals are ignored because we process OVSDBBridgeSetup event instead
            if 'old' in uo:
                if 'ports' in uo['old']:
                    oset = set((puuid for _, puuid in ovsdb.getlist(uo['old']['ports'])))
                    ignore_ports.update(oset)
                if 'new' not in uo:
                    if buuid in self.bridge_datapathid:
                        del self.bridge_datapathid[buuid]
            if 'new' in uo:
                # If bridge contains this port is updated, we process the port update totally in bridge,
                # so we ignore it later
                if 'ports' in uo['new']:
                    nset = set((puuid for _, puuid in ovsdb.getlist(uo['new']['ports'])))
                    ignore_ports.update(nset)
                working_routines.append(process_bridge(buuid, uo))                
        def process_port(buuid, port_uuid, interfaces, remove_ids):
            if buuid not in self.bridge_datapathid:
                return
            datapath_id = self.bridge_datapathid[buuid]
            ports = self.managed_ports.get((protocol.vhost, datapath_id))
            remove = []
            if ports is not None:
                remove = [p for _,p in ports if p['_uuid'] in remove_ids]
                not_remove = [(_,p) for _,p in ports if p['_uuid'] not in remove_ids]
                ports[:len(not_remove)] = not_remove
                del ports[len(not_remove):]
            if interfaces:
                try:
                    with closing(self.apiroutine.executeAll([self._get_interface_info(connection, protocol, buuid, iuuid, port_uuid)
                                                         for iuuid in interfaces])) as g:
                        for m in g:
                            yield m
                    add = list(itertools.chain((r[0] for r in self.apiroutine.retvalue if r[0])))
                except Exception:
                    self._logger.warning("Cannot get new port information", exc_info = True)
                    add = []
            else:
                add = []
            if update:
                for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'update', datapathid = datapath_id,
                                                                                                          connection = connection,
                                                                                                          vhost = protocol.vhost,
                                                                                                          add = add, remove = remove,
                                                                                                          reason = 'bridgemodify'
                                                                                                          )):
                    yield m
        for puuid, po in port_update.items():
            if puuid not in ignore_ports:
                bridge_id = self.ports_uuids.get(puuid)
                if bridge_id is not None:
                    datapath_id = self.bridge_datapathid[bridge_id]
                    if datapath_id is not None:
                        # This port is modified
                        if 'new' in po:
                            nset = set((iuuid for _, iuuid in ovsdb.getlist(po['new']['interfaces'])))
                        else:
                            nset = set()                    
                        if 'old' in po:
                            oset = set((iuuid for _, iuuid in ovsdb.getlist(po['old']['interfaces'])))
                        else:
                            oset = set()
                        working_routines.append(process_port(bridge_id, puuid, nset - oset, oset - nset))
        if update:
            for r in working_routines:
                self.apiroutine.subroutine(r)
        else:
            try:
                with closing(self.apiroutine.executeAll(working_routines, None, ())) as g:
                    for m in g:
                        yield m
            finally:
                self.scheduler.emergesend(OVSDBConnectionPortsSynchronized(connection))
    def _get_ports(self, connection, protocol):
        """
        Set up an OVSDB monitor on the Bridge and Port tables of *connection*
        and dispatch every monitor update to ``_update_interfaces``.

        If the monitor setup fails, ``OVSDBConnectionPortsSynchronized`` is
        still sent so routines waiting for the initial synchronization are
        released. Runs until the connection state matcher fires.
        """
        try:
            try:
                # Watch bridge name/datapath_id/ports, plus each port's interfaces
                method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_port_manager_interfaces_monitor', {
                                                    'Bridge':[ovsdb.monitor_request(["name", "datapath_id", "ports"])],
                                                    'Port':[ovsdb.monitor_request(["interfaces"])]
                                                })
                for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                    yield m
                if 'error' in self.apiroutine.jsonrpc_result:
                    # The monitor is already set (stale from a previous run), cancel it first
                    method2, params2 = ovsdb.monitor_cancel('ovsdb_port_manager_interfaces_monitor')
                    for m in protocol.querywithreply(method2, params2, connection, self.apiroutine, False):
                        yield m
                    # Retry the monitor request once after the cancel
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    if 'error' in self.apiroutine.jsonrpc_result:
                        raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result))
                r = self.apiroutine.jsonrpc_result
            except:
                # Even on failure, release waiters blocked on the synchronized event
                for m in self.apiroutine.waitForSend(OVSDBConnectionPortsSynchronized(connection)):
                    yield m
                raise
            # This is the initial state; it should contain all the ids of ports and interfaces
            self.apiroutine.subroutine(self._update_interfaces(connection, protocol, r, False))
            update_matcher = JsonRPCNotificationEvent.createMatcher('update', connection, connection.connmark,
                                                                    _ismatch = lambda x: x.params[0] == 'ovsdb_port_manager_interfaces_monitor')
            conn_state = protocol.statematcher(connection)
            while True:
                yield (update_matcher, conn_state)
                if self.apiroutine.matcher is conn_state:
                    # Connection went down; stop monitoring
                    break
                else:
                    # Incremental monitor update; process it in a subroutine
                    self.apiroutine.subroutine(self._update_interfaces(connection, protocol, self.apiroutine.event.params[1], True))
        except JsonRPCProtocolException:
            pass
        finally:
            # Deregister this routine from the set of active monitors
            if self.apiroutine.currentroutine in self.monitor_routines:
                self.monitor_routines.remove(self.apiroutine.currentroutine)
    def _get_existing_ports(self):
        """
        Query all current OVSDB connections, start a port monitor on each,
        and wait until every monitor reports its initial state before marking
        the module as synchronized.
        """
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getallconnections', {'vhost':None}):
            yield m
        connections = self.apiroutine.retvalue
        # Launch one monitor routine per connection; keep them so they can be
        # closed on shutdown
        for conn in connections:
            self.monitor_routines.add(self.apiroutine.subroutine(self._get_ports(conn, conn.protocol)))
        sync_matchers = [OVSDBConnectionPortsSynchronized.createMatcher(conn)
                         for conn in connections]
        # Block until every monitor has delivered its initial snapshot
        for m in self.apiroutine.waitForAll(*sync_matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m
    def _wait_for_sync(self):
        """Wait until the initial port synchronization has completed; no-op if already synchronized."""
        if self._synchronized:
            return
        # Block until this module broadcasts its 'synchronized' notification
        yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)
    def _manage_ports(self):
        """
        Main event loop of the port manager.

        Starts the initial synchronization, then reacts to two events:

        - ``OVSDBConnectionSetup``: a new OVSDB connection appeared; start a
          port monitor routine for it.
        - ``OVSDBBridgeSetup`` (DOWN): a bridge went down or changed its
          datapath id; drop its managed ports, clean the uuid indexes and
          notify consumers.

        On exit, closes all monitor routines and sends an 'unsynchronized'
        notification.
        """
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            connsetup = OVSDBConnectionSetup.createMatcher()
            bridgedown = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.DOWN)
            while True:
                yield (connsetup, bridgedown)
                e = self.apiroutine.event
                if self.apiroutine.matcher is connsetup:
                    # New OVSDB connection: start monitoring its ports
                    self.monitor_routines.add(self.apiroutine.subroutine(self._get_ports(e.connection, e.connection.protocol)))
                else:
                    # Remove ports of the bridge
                    ports = self.managed_ports.get((e.vhost, e.datapathid))
                    if ports is not None:
                        ports_original = ports
                        ports = [p for _,p in ports]
                        for p in ports:
                            if p['id']:
                                # Deregister the logical id -> port mapping
                                self._remove_interface_id(e.connection,
                                                          e.connection.protocol, e.datapathid, p)
                        newdpid = getattr(e, 'new_datapath_id', None)
                        buuid = e.bridgeuuid
                        if newdpid is not None:
                            # This bridge changes its datapath id
                            if buuid in self.bridge_datapathid and self.bridge_datapathid[buuid] == e.datapathid:
                                self.bridge_datapathid[buuid] = newdpid
                            def re_add_interfaces():
                                # Re-query the interfaces and announce them under the new datapath id
                                with closing(self.apiroutine.executeAll(
                                    [self._get_interface_info(e.connection, e.connection.protocol, buuid,
                                                              r['_uuid'], puuid)
                                     for puuid, r in ports_original])) as g:
                                    for m in g:
                                        yield m
                                add = list(itertools.chain(r[0] for r in self.apiroutine.retvalue))
                                for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(),
                                                  'update', datapathid = e.datapathid,
                                                  connection = e.connection,
                                                  vhost = e.vhost,
                                                  add = add, remove = [],
                                                  reason = 'bridgeup'
                                                  )):
                                    yield m
                            self.apiroutine.subroutine(re_add_interfaces())
                        else:
                            # The ports are removed; clean the port-uuid -> bridge-uuid index.
                            # Fixed: the original tested `puuid in self.ports_uuids[puuid]`,
                            # indexing the dict instead of testing key membership.
                            for puuid, _ in ports_original:
                                if puuid in self.ports_uuids and self.ports_uuids[puuid] == buuid:
                                    del self.ports_uuids[puuid]
                        del self.managed_ports[(e.vhost, e.datapathid)]
                        self.scheduler.emergesend(ModuleNotification(self.getServiceName(),
                                                  'update', datapathid = e.datapathid,
                                                  connection = e.connection,
                                                  vhost = e.vhost,
                                                  add = [], remove = ports,
                                                  reason = 'bridgedown'
                                                  ))
        finally:
            # Shut down all monitor routines and tell consumers we are out of sync
            for r in list(self.monitor_routines):
                r.close()
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))
    def getports(self, datapathid, vhost = ''):
        "Return all ports of a specified datapath"
        for m in self._wait_for_sync():
            yield m
        # managed_ports stores (uuid, port) pairs; expose only the port dicts
        stored = self.managed_ports.get((vhost, datapathid), [])
        self.apiroutine.retvalue = [port for _uuid, port in stored]
    def getallports(self, vhost = None):
        "Return all (datapathid, port, vhost) tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        result = []
        for (vh, dpid), portlist in self.managed_ports.items():
            # vhost=None means no filtering
            if vhost is None or vh == vhost:
                result.extend((dpid, port, vh) for _uuid, port in portlist)
        self.apiroutine.retvalue = result
    def getportbyno(self, datapathid, portno, vhost = ''):
        "Return port with specified portno"
        for m in self._wait_for_sync():
            yield m
        # OpenFlow port numbers are compared on the lower 16 bits only
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno & 0xffff, vhost)
    def _getportbyno(self, datapathid, portno, vhost = ''):
        # Linear scan of the datapath's managed ports; None when the datapath
        # is unknown or no port has the requested ofport number
        entries = self.managed_ports.get((vhost, datapathid))
        if entries is None:
            return None
        return next((port for _uuid, port in entries if port['ofport'] == portno), None)
    def waitportbyno(self, datapathid, portno, timeout = 30, vhost = ''):
        "Wait for port with specified portno"
        # Only the lower 16 bits identify an OpenFlow port number here
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m
        def waitinner():
            # Fast path: the port is already known
            p = self._getportbyno(datapathid, portno, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Register interest so notifiers know somebody waits on this
                    # (vhost, datapathid, portno); entries are reference-counted.
                    self.wait_portnos[(vhost, datapathid, portno)] = \
                            self.wait_portnos.get((vhost, datapathid, portno),0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(None, None, portno, None, vhost, datapathid),)
                except:
                    # Interrupted (e.g. timeout): drop our reference again.
                    # NOTE(review): the success path does not decrement; presumably
                    # the notifier clears the entry when sending the event -- confirm.
                    v = self.wait_portnos.get((vhost, datapathid, portno))
                    if v is not None:
                        if v <= 1:
                            del self.wait_portnos[(vhost, datapathid, portno)]
                        else:
                            self.wait_portnos[(vhost, datapathid, portno)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(portno) + ' does not appear before timeout')
    def getportbyname(self, datapathid, name, vhost = ''):
        "Return port with specified name"
        # Names are stored as text; accept bytes callers for convenience
        name = name.decode('utf-8') if isinstance(name, bytes) else name
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)
    def _getportbyname(self, datapathid, name, vhost = ''):
        # Linear scan over the datapath's ports; None when the datapath is
        # unknown or no port carries the requested name
        entries = self.managed_ports.get((vhost, datapathid))
        if entries is None:
            return None
        return next((port for _uuid, port in entries if port['name'] == name), None)
    def waitportbyname(self, datapathid, name, timeout = 30, vhost = ''):
        "Wait for port with specified name"
        for m in self._wait_for_sync():
            yield m
        def waitinner():
            # Fast path: the port is already known
            p = self._getportbyname(datapathid, name, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Register interest; entries are reference-counted per
                    # (vhost, datapathid, name).
                    # Fixed: the original read the current count from
                    # self.wait_portnos (the portno index) instead of
                    # self.wait_names, corrupting the reference count.
                    self.wait_names[(vhost, datapathid, name)] = \
                            self.wait_names.get((vhost, datapathid, name), 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(None, name, None, None, vhost, datapathid),)
                except:
                    # Interrupted (e.g. timeout): drop our reference again
                    v = self.wait_names.get((vhost, datapathid, name))
                    if v is not None:
                        if v <= 1:
                            del self.wait_names[(vhost, datapathid, name)]
                        else:
                            self.wait_names[(vhost, datapathid, name)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(name) + ' does not appear before timeout')
    def getportbyid(self, id, vhost = ''):
        "Return port with the specified id. The return value is a pair: (datapath_id, port)"
        for m in self._wait_for_sync():
            yield m
        # Fixed: the original assigned the result to self.apiroutine itself,
        # clobbering the routine container instead of setting the return value.
        self.apiroutine.retvalue = self._getportbyid(id, vhost)
    def _getportbyid(self, id, vhost = ''):
        # managed_ids maps (vhost, id) -> list of (datapath_id, port) pairs;
        # return the first entry, or None when the id is unknown or the list empty
        entries = self.managed_ids.get((vhost, id))
        return entries[0] if entries else None
    def waitportbyid(self, id, timeout = 30, vhost = ''):
        "Wait for port with the specified id. The return value is a pair (datapath_id, port)"
        for m in self._wait_for_sync():
            yield m
        def waitinner():
            # Fast path: the id is already mapped to a port
            p = self._getportbyid(id, vhost)
            if p is None:
                try:
                    # Register interest; entries are reference-counted per (vhost, id)
                    self.wait_ids[(vhost, id)] = self.wait_ids.get((vhost, id), 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(None, None, None, id, vhost),)
                except:
                    # Interrupted (e.g. timeout): drop our reference again.
                    # NOTE(review): the success path does not decrement; presumably
                    # the notifier clears the entry when sending the event -- confirm.
                    v = self.wait_ids.get((vhost, id))
                    if v is not None:
                        if v <= 1:
                            del self.wait_ids[(vhost, id)]
                        else:
                            self.wait_ids[(vhost, id)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = (self.apiroutine.event.datapathid,
                                                self.apiroutine.event.port)
            else:
                self.apiroutine.retvalue = p
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(id) + ' does not appear before timeout')
    def resync(self, datapathid, vhost = ''):
        '''
        Resync with current ports
        '''
        # When the OVSDB connection is very busy, monitor messages may be
        # dropped. Recover by forcing a reconnect, which re-runs the monitor
        # setup and rebuilds the port state from scratch.
        if (vhost, datapathid) in self.managed_ports:
            for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getconnection', {'datapathid': datapathid, 'vhost':vhost}):
                yield m
            conn = self.apiroutine.retvalue
            if conn is not None:
                # For now, we simply restart the connection and wait for it to come back
                for m in conn.reconnect(False):
                    yield m
                for m in self.apiroutine.waitWithTimeout(0.1):
                    yield m
                for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitconnection', {'datapathid': datapathid,
                                                                                     'vhost': vhost}):
                    yield m
        self.apiroutine.retvalue = None