示例#1
0
class TestModule(Module):
    """Exercise ZooKeeperClient: create/read/delete nodes, multi-ops, and watches."""
    # Default ZooKeeper ensemble to connect to (3-node local test cluster)
    _default_serverlist = ['tcp://localhost:3181/','tcp://localhost:3182/','tcp://localhost:3183/']
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.client = ZooKeeperClient(self.apiroutine, self.serverlist)
        self.connections.append(self.client)
        self.apiroutine.main = self.main
        self.routines.append(self.apiroutine)
    def watcher(self):
        # Daemon routine: dump every ZooKeeper watcher event as it arrives.
        watcher = ZooKeeperWatcherEvent.createMatcher()
        while True:
            yield (watcher,)
            print('WatcherEvent: %r' % (dump(self.apiroutine.event.message),))
    def main(self):
        def _watch(w):
            # Wait for a single watcher object to fire, then print its result.
            for m in w.wait(self.apiroutine):
                yield m
            print('Watcher returns:', dump(self.apiroutine.retvalue))
        def _watchall(watchers):
            # Spawn a waiter per returned watcher; entries are None for
            # requests that did not set a watch.
            for w in watchers:
                if w is not None:
                    self.apiroutine.subroutine(_watch(w))
        self.apiroutine.subroutine(self.watcher(), False, daemon = True)
        # Wait until the client session is established before issuing requests.
        up = ZooKeeperSessionStateChanged.createMatcher(ZooKeeperSessionStateChanged.CREATED, self.client)
        yield (up,)
        print('Connection is up: %r' % (self.client.currentserver,))
        # Create a node and read it back with a watch set.
        for m in self.client.requests([zk.create(b'/vlcptest', b'test'),
                                       zk.getdata(b'/vlcptest', True)], self.apiroutine):
            yield m
        print(self.apiroutine.retvalue)
        pprint(dump(self.apiroutine.retvalue[0]))
        # retvalue[3] is presumably the per-request watcher list -- TODO confirm
        # against ZooKeeperClient.requests' documented return tuple.
        _watchall(self.apiroutine.retvalue[3])
        for m in self.apiroutine.waitWithTimeout(0.2):
            yield m
        # Delete the node (fires the data watch) and set a new watch on the
        # now-missing node.
        for m in self.client.requests([zk.delete(b'/vlcptest'),
                                        zk.getdata(b'/vlcptest', watch = True)], self.apiroutine):
            yield m
        print(self.apiroutine.retvalue)
        pprint(dump(self.apiroutine.retvalue[0]))
        _watchall(self.apiroutine.retvalue[3])
        # Atomically create a parent and child, then list children with a watch.
        for m in self.client.requests([zk.multi(
                                        zk.multi_create(b'/vlcptest2', b'test'),
                                        # Bug fix: payload must be bytes (was the
                                        # str 'test2'); ZooKeeper node data is
                                        # binary, matching every other call here.
                                        zk.multi_create(b'/vlcptest2/subtest', b'test2')
                                    ),
                                  zk.getchildren2(b'/vlcptest2', True)], self.apiroutine):
            yield m
        print(self.apiroutine.retvalue)
        pprint(dump(self.apiroutine.retvalue[0]))
        _watchall(self.apiroutine.retvalue[3])
        # Atomically delete child then parent (fires the children watch).
        for m in self.client.requests([zk.multi(
                                        zk.multi_delete(b'/vlcptest2/subtest'),
                                        zk.multi_delete(b'/vlcptest2')),
                                  zk.getchildren2(b'/vlcptest2', True)], self.apiroutine):
            yield m
        print(self.apiroutine.retvalue)
        pprint(dump(self.apiroutine.retvalue[0]))
        _watchall(self.apiroutine.retvalue[3])
示例#2
0
class TestModule(Module):
    """Exercise the raw ZooKeeper protocol over a plain Client connection:
    handshake, create/read/delete nodes, multi-ops, and watches."""
    # Single-server connection URL and session timeout (seconds)
    _default_url = 'tcp://localhost/'
    _default_sessiontimeout = 30
    def __init__(self, server):
        Module.__init__(self, server)
        self.protocol = ZooKeeper()
        self.client = Client(self.url, self.protocol, self.scheduler)
        self.connections.append(self.client)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self.main
        self.routines.append(self.apiroutine)
    def watcher(self):
        # Daemon routine: dump watcher events arriving on this connection only.
        watcher = ZooKeeperWatcherEvent.createMatcher(connection = self.client)
        while True:
            yield (watcher,)
            print('WatcherEvent: %r' % (dump(self.apiroutine.event.message),))
    def main(self):
        self.apiroutine.subroutine(self.watcher(), False, daemon = True)
        # Wait for either a successful TCP connection or a definite failure.
        up = ZooKeeperConnectionStateEvent.createMatcher(ZooKeeperConnectionStateEvent.UP, self.client)
        notconn = ZooKeeperConnectionStateEvent.createMatcher(ZooKeeperConnectionStateEvent.NOTCONNECTED, self.client)
        yield (up, notconn)
        if self.apiroutine.matcher is notconn:
            print('Not connected')
            return
        else:
            print('Connection is up: %r' % (self.client,))
        # Handshake: establish a new ZooKeeper session (no prior session id).
        for m in self.protocol.handshake(self.client, zk.ConnectRequest(
                                                        timeOut = int(self.sessiontimeout * 1000),
                                                        passwd = b'\x00' * 16,      # Why is it necessary...
                                                    ), self.apiroutine, []):
            yield m
        # Create a node and read it back with a watch set.
        for m in self.protocol.requests(self.client, [zk.create(b'/vlcptest', b'test'),
                                                      zk.getdata(b'/vlcptest', True)], self.apiroutine):
            yield m
        pprint(dump(self.apiroutine.retvalue[0]))
        for m in self.apiroutine.waitWithTimeout(0.2):
            yield m
        # Delete the node (fires the data watch) and set a new watch.
        for m in self.protocol.requests(self.client, [zk.delete(b'/vlcptest'),
                                                      zk.getdata(b'/vlcptest', watch = True)], self.apiroutine):
            yield m
        pprint(dump(self.apiroutine.retvalue[0]))
        # Atomically create a parent and child, then list children with a watch.
        for m in self.protocol.requests(self.client, [zk.multi(
                                                            zk.multi_create(b'/vlcptest2', b'test'),
                                                            # Bug fix: payload must be bytes
                                                            # (was the str 'test2'), matching
                                                            # every other call in this module.
                                                            zk.multi_create(b'/vlcptest2/subtest', b'test2')
                                                        ),
                                                      zk.getchildren2(b'/vlcptest2', True)], self.apiroutine):
            yield m
        pprint(dump(self.apiroutine.retvalue[0]))
        # Atomically delete child then parent (fires the children watch).
        for m in self.protocol.requests(self.client, [zk.multi(
                                                            zk.multi_delete(b'/vlcptest2/subtest'),
                                                            zk.multi_delete(b'/vlcptest2')),
                                                      zk.getchildren2(b'/vlcptest2', True)], self.apiroutine):
            yield m
        pprint(dump(self.apiroutine.retvalue[0]))
示例#3
0
class OpenflowManager(Module):
    '''
    Manage Openflow Connections
    '''
    service = True
    # Restrict managed connections to these vhosts; None = manage every vhost
    _default_vhostbind = None

    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> [connection, ...]
        self.managed_conns = {}
        # (vhost, endpoint) -> [connection, ...]
        self.endpoint_conns = {}
        # Names of modules that have requested flow tables
        self.table_modules = set()
        self._acquiring = False
        self._acquire_updated = False
        # Last successful table-acquire result: {vhost: (full_indices, tables)}
        self._lastacquire = None
        self._synchronized = False
        self.createAPI(api(self.getconnections, self.apiroutine),
                       api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.acquiretable, self.apiroutine),
                       api(self.unacquiretable, self.apiroutine),
                       api(self.lastacquiredtables))

    def _add_connection(self, conn):
        """Register *conn*; replace any stale connection with the same
        auxiliary id. Returns the list of replaced connections."""
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault(
            (vhost, conn.openflow_datapathid), [])
        remove = []
        for i in range(0, len(conns)):
            if conns[i].openflow_auxiliaryid == conn.openflow_auxiliaryid:
                ci = conns[i]
                remove = [ci]
                # Drop the stale connection from the endpoint index as well
                ep = _get_endpoint(ci)
                econns = self.endpoint_conns.get((vhost, ep))
                if econns is not None:
                    try:
                        econns.remove(ci)
                    except ValueError:
                        pass
                    if not econns:
                        del self.endpoint_conns[(vhost, ep)]
                del conns[i]
                break
        conns.append(conn)
        ep = _get_endpoint(conn)
        econns = self.endpoint_conns.setdefault((vhost, ep), [])
        econns.append(conn)
        # Only the main connection (auxiliary id 0) triggers flow initialization
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return remove

    def _initialize_connection(self, conn):
        """Reset the switch (delete all flows/groups) and install the default
        goto-table flows, then broadcast FlowInitialize."""
        ofdef = conn.openflowdef
        flow_mod = ofdef.ofp_flow_mod(buffer_id=ofdef.OFP_NO_BUFFER,
                                      out_port=ofdef.OFPP_ANY,
                                      command=ofdef.OFPFC_DELETE)
        # The following fields only exist on OpenFlow 1.1+; guard with hasattr
        # so OpenFlow 1.0 connections still work.
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            group_mod = ofdef.ofp_group_mod(command=ofdef.OFPGC_DELETE,
                                            group_id=ofdef.OFPG_ALL)
            cmds.append(group_mod)
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows: in every acquired table path, a priority-0
            # flow in each table jumps to the next table of the path.
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                cmds = [
                    ofdef.ofp_flow_mod(table_id=t[i][1],
                                       command=ofdef.OFPFC_ADD,
                                       priority=0,
                                       buffer_id=ofdef.OFP_NO_BUFFER,
                                       out_port=ofdef.OFPP_ANY,
                                       out_group=ofdef.OFPG_ANY,
                                       match=ofdef.ofp_match_oxm(),
                                       instructions=[
                                           ofdef.ofp_instruction_goto_table(
                                               table_id=t[i + 1][1])
                                       ]) for _, t in pathtable.items()
                    for i in range(0,
                                   len(t) - 1)
                ]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        for m in self.apiroutine.waitForSend(
                FlowInitialize(conn, conn.openflow_datapathid,
                               conn.protocol.vhost)):
            yield m

    def _acquire_tables(self):
        """Collect table requests from all registered modules, topologically
        sort them per vhost, and broadcast the result (or the failure)."""
        try:
            while self._acquire_updated:
                result = None
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(), )
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll(
                        (callAPI(self.apiroutine, module, 'gettablerequest',
                                 {}) for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests if vhs is not None
                                 for vh in vhs)
                    vhost_result = {}
                    # Requests should be list of (name, (ancester, ancester, ...), pathname)
                    for vh in vhosts:
                        # graph: name -> (set of ancestors, set of descendants)
                        graph = {}
                        table_path = {}
                        try:
                            for r in requests:
                                if r[1] is None or vh in r[1]:
                                    for name, ancesters, pathname in r[0]:
                                        if name in table_path:
                                            if table_path[name] != pathname:
                                                raise ValueError(
                                                    "table conflict detected: %r can not be in two path: %r and %r"
                                                    % (name, table_path[name],
                                                       pathname))
                                        else:
                                            table_path[name] = pathname
                                        if name not in graph:
                                            graph[name] = (set(ancesters),
                                                           set())
                                        else:
                                            graph[name][0].update(ancesters)
                                        for anc in ancesters:
                                            graph.setdefault(
                                                anc,
                                                (set(), set()))[1].add(name)
                        except ValueError as exc:
                            self._logger.error(str(exc))
                            exception = exc
                            break
                        else:
                            sequences = []

                            def dfs_sort(current):
                                # Kahn-style topological visit: emit a node
                                # once all its ancestors have been emitted.
                                sequences.append(current)
                                for d in graph[current][1]:
                                    anc = graph[d][0]
                                    anc.remove(current)
                                    if not anc:
                                        dfs_sort(d)

                            # Bug fix: the sort key must depend on the sorted
                            # element x; the original closed over the stale
                            # loop variable `name`, making the key constant.
                            nopre_tables = sorted(
                                [k for k, v in graph.items() if not v[0]],
                                key=lambda x: (table_path.get(x, ''), x))
                            for t in nopre_tables:
                                dfs_sort(t)
                            if len(sequences) < len(graph):
                                # Nodes never emitted => dependency cycle
                                rest_tables = set(
                                    graph.keys()).difference(sequences)
                                self._logger.error(
                                    "Circle detected in table acquiring, following tables are related: %r, vhost = %r",
                                    sorted(rest_tables), vh)
                                self._logger.error(
                                    "Circle dependencies are: %s", ", ".join(
                                        repr(tuple(graph[t][0])) + "=>" + t
                                        for t in rest_tables))
                                exception = ValueError(
                                    "Circle detected in table acquiring, following tables are related: %r, vhost = %r"
                                    % (sorted(rest_tables), vh))
                                break
                            elif len(sequences) > 255:
                                # OpenFlow table ids are a single byte
                                self._logger.error(
                                    "Table limit exceeded: %d tables (only 255 allowed), vhost = %r",
                                    len(sequences), vh)
                                exception = ValueError(
                                    "Table limit exceeded: %d tables (only 255 allowed), vhost = %r"
                                    % (len(sequences), vh))
                                break
                            else:
                                # Assign table ids in topological order and
                                # group them by path name.
                                full_indices = list(
                                    zip(sequences, itertools.count()))
                                tables = dict(
                                    (k, tuple(g))
                                    for k, g in itertools.groupby(
                                        sorted(full_indices,
                                               key=lambda x: table_path.get(
                                                   x[0], '')),
                                        lambda x: table_path.get(x[0], '')))
                                vhost_result[vh] = (full_indices, tables)
        finally:
            self._acquiring = False
        if exception:
            for m in self.apiroutine.waitForSend(
                    TableAcquireUpdate(exception=exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                self._lastacquire = result
                self._reinitall()
            for m in self.apiroutine.waitForSend(
                    TableAcquireUpdate(result=result)):
                yield m

    def load(self, container):
        # Delay queue limits concurrent table-acquire delay events to 1
        self.scheduler.queue.addSubQueue(
            1, TableAcquireDelayEvent.createMatcher(),
            'ofpmanager_tableacquiredelay')
        for m in container.waitForSend(TableAcquireUpdate(result=None)):
            yield m
        for m in Module.load(self, container):
            yield m

    def unload(self, container, force=False):
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(
                syscall_removequeue(self.scheduler.queue,
                                    'ofpmanager_tableacquiredelay')):
            yield m

    def _reinitall(self):
        # Re-run flow initialization on every managed connection
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))

    def _manage_existing(self):
        """Pick up connections that were already established before this
        module loaded, then announce synchronization."""
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections",
                         {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(
                ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m

    def _wait_for_sync(self):
        # Block callers until _manage_existing has finished
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(),
                                                    'synchronized'), )

    def _manage_conns(self):
        """Main loop: track connection setup/teardown events and keep the
        managed_conns / endpoint_conns indexes up to date."""
        vb = self.vhostbind
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP,
                    _ismatch=lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN,
                    _ismatch=lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(
                    state=OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    e = self.apiroutine.event
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(
                        ModuleNotification(self.getServiceName(),
                                           'update',
                                           add=[e.connection],
                                           remove=remove))
                else:
                    e = self.apiroutine.event
                    conns = self.managed_conns.get(
                        (e.createby.vhost, e.datapathid))
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            pass
                        else:
                            remove.append(e.connection)

                        if not conns:
                            del self.managed_conns[(e.createby.vhost,
                                                    e.datapathid)]
                        # Also delete from endpoint_conns
                        ep = _get_endpoint(e.connection)
                        econns = self.endpoint_conns.get(
                            (e.createby.vhost, ep))
                        if econns is not None:
                            try:
                                econns.remove(e.connection)
                            except ValueError:
                                pass
                            if not econns:
                                del self.endpoint_conns[(e.createby.vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(
                            ModuleNotification(self.getServiceName(),
                                               'update',
                                               add=[],
                                               remove=remove))
        finally:
            self.scheduler.emergesend(
                ModuleNotification(self.getServiceName(), 'unsynchronized'))

    def getconnections(self, datapathid, vhost=''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(
            self.managed_conns.get((vhost, datapathid), []))

    def getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid,
                                                       vhost)

    def _getconnection(self, datapathid, auxiliaryid=0, vhost=''):
        # Synchronous lookup helper; returns None when not connected
        conns = self.managed_conns.get((vhost, datapathid))
        if conns is None:
            return None
        else:
            for c in conns:
                if c.openflow_auxiliaryid == auxiliaryid:
                    return c
            return None

    def waitconnection(self, datapathid, auxiliaryid=0, timeout=30, vhost=''):
        "Wait for a datapath connection"
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is None:
            for m in self.apiroutine.waitWithTimeout(
                    timeout,
                    OpenflowConnectionStateEvent.createMatcher(
                        datapathid,
                        auxiliaryid,
                        OpenflowConnectionStateEvent.CONNECTION_SETUP,
                        _ismatch=lambda x: x.createby.vhost == vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException(
                    'Datapath %016x is not connected' % datapathid)
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c

    def getdatapathids(self, vhost=''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [
            k[1] for k in self.managed_conns.keys() if k[0] == vhost
        ]

    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return ``(vhost, datapathid)`` pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())

    def getallconnections(self, vhost=''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        # Bug fix: managed_conns values are *lists* of connections, so they
        # must be flattened with chain.from_iterable; plain chain(iterable)
        # treats its single argument as one iterable and would yield the
        # lists themselves instead of the connections.
        if vhost is None:
            self.apiroutine.retvalue = list(
                itertools.chain.from_iterable(self.managed_conns.values()))
        else:
            self.apiroutine.retvalue = list(
                itertools.chain.from_iterable(
                    v for k, v in self.managed_conns.items()
                    if k[0] == vhost))

    def getconnectionsbyendpoint(self, endpoint, vhost=''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))

    def getconnectionsbyendpointname(self, name, vhost='', timeout=30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        # Resolve the name
        if not name:
            endpoint = ''
            # Bug fix: the method is getconnectionsbyendpoint (plural);
            # the original called a non-existent getconnectionbyendpoint
            # which would raise AttributeError.
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM,
                       socket.IPPROTO_TCP,
                       socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(
                    timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                # Resolve is only allowed through asynchronous resolver
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                # Try each resolved address until one has connections
                for r in resp:
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # Ignore port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    if self.apiroutine.retvalue is not None:
                        break

    def getendpoints(self, vhost=''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [
            k[1] for k in self.endpoint_conns if k[0] == vhost
        ]

    def getallendpoints(self):
        "Get all endpoints from any vhost. Return ``(vhost, endpoint)`` pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())

    def lastacquiredtables(self, vhost=""):
        "Get acquired table IDs"
        # Robustness fix: before the first acquisition _lastacquire is None;
        # return None instead of raising AttributeError.
        if self._lastacquire is None:
            return None
        return self._lastacquire.get(vhost)

    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if not modulename in self.table_modules:
            self.table_modules.add(modulename)
            self._acquire_updated = True
            # Only one _acquire_tables routine runs at a time; it loops while
            # _acquire_updated stays set.
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        if False:
            yield

    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.remove(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        if False:
            yield
示例#4
0
class OpenflowManager(Module):
    '''
    Manage Openflow Connections
    '''
    service = True
    # None means bind to every vhost; a list restricts to those vhost names
    _default_vhostbind = None
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapathid) -> [connection, ...]
        self.managed_conns = {}
        # (vhost, endpoint) -> [connection, ...]
        self.endpoint_conns = {}
        # Names of modules which requested table acquiring
        self.table_modules = set()
        self._acquiring = False
        self._acquire_updated = False
        # Last table acquiring result: {vhost: (full_indices, path_tables)};
        # None before the first successful acquire
        self._lastacquire = None
        self._synchronized = False
        self.createAPI(api(self.getconnections, self.apiroutine),
                       api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.acquiretable, self.apiroutine),
                       api(self.unacquiretable, self.apiroutine),
                       api(self.lastacquiredtables)
                       )
    def _add_connection(self, conn):
        """Register *conn* in the datapath and endpoint indices.

        An existing connection with the same auxiliary ID is replaced.
        Returns the list of replaced connections (possibly empty).
        """
        vhost = conn.protocol.vhost
        conns = self.managed_conns.setdefault((vhost, conn.openflow_datapathid), [])
        remove = []
        for i in range(0, len(conns)):
            if conns[i].openflow_auxiliaryid == conn.openflow_auxiliaryid:
                ci = conns[i]
                remove = [ci]
                # Drop the replaced connection from the endpoint index as well
                ep = _get_endpoint(ci)
                econns = self.endpoint_conns.get((vhost, ep))
                if econns is not None:
                    try:
                        econns.remove(ci)
                    except ValueError:
                        pass
                    if not econns:
                        del self.endpoint_conns[(vhost, ep)]
                del conns[i]
                break
        conns.append(conn)
        ep = _get_endpoint(conn)
        econns = self.endpoint_conns.setdefault((vhost, ep), [])
        econns.append(conn)
        # Only initialize flow tables on the main connection (auxiliary id 0)
        if self._lastacquire and conn.openflow_auxiliaryid == 0:
            self.apiroutine.subroutine(self._initialize_connection(conn))
        return remove
    def _initialize_connection(self, conn):
        """Reset a newly connected datapath: delete all flows (and groups),
        install priority-0 goto-table flows along each acquired table path,
        then broadcast a FlowInitialize event."""
        ofdef = conn.openflowdef
        flow_mod = ofdef.ofp_flow_mod(buffer_id = ofdef.OFP_NO_BUFFER,
                                                 out_port = ofdef.OFPP_ANY,
                                                 command = ofdef.OFPFC_DELETE
                                                 )
        # These fields do not exist in every Openflow version; guard with hasattr
        if hasattr(ofdef, 'OFPG_ANY'):
            flow_mod.out_group = ofdef.OFPG_ANY
        if hasattr(ofdef, 'OFPTT_ALL'):
            flow_mod.table_id = ofdef.OFPTT_ALL
        if hasattr(ofdef, 'ofp_match_oxm'):
            flow_mod.match = ofdef.ofp_match_oxm()
        cmds = [flow_mod]
        if hasattr(ofdef, 'ofp_group_mod'):
            group_mod = ofdef.ofp_group_mod(command = ofdef.OFPGC_DELETE,
                                            group_id = ofdef.OFPG_ALL
                                            )
            cmds.append(group_mod)
        for m in conn.protocol.batch(cmds, conn, self.apiroutine):
            yield m
        if hasattr(ofdef, 'ofp_instruction_goto_table'):
            # Create default flows: every table in an acquired path falls
            # through to the next table in the path
            vhost = conn.protocol.vhost
            if self._lastacquire and vhost in self._lastacquire:
                _, pathtable = self._lastacquire[vhost]
                cmds = [ofdef.ofp_flow_mod(table_id = t[i][1],
                                             command = ofdef.OFPFC_ADD,
                                             priority = 0,
                                             buffer_id = ofdef.OFP_NO_BUFFER,
                                             out_port = ofdef.OFPP_ANY,
                                             out_group = ofdef.OFPG_ANY,
                                             match = ofdef.ofp_match_oxm(),
                                             instructions = [ofdef.ofp_instruction_goto_table(table_id = t[i+1][1])]
                                       )
                          for _,t in pathtable.items()
                          for i in range(0, len(t) - 1)]
                if cmds:
                    for m in conn.protocol.batch(cmds, conn, self.apiroutine):
                        yield m
        for m in self.apiroutine.waitForSend(FlowInitialize(conn, conn.openflow_datapathid, conn.protocol.vhost)):
            yield m
    def _acquire_tables(self):
        """Background routine: collect table requests from every registered
        module, topologically sort the tables per vhost, and broadcast the
        result (or the failure) with a TableAcquireUpdate event."""
        # Initialize before the loop: if the loop body never runs, the code
        # after ``finally`` would otherwise raise NameError
        result = None
        exception = None
        vhost_result = {}
        try:
            while self._acquire_updated:
                result = None
                exception = None
                # Delay the update so we are not updating table acquires for every module
                for m in self.apiroutine.waitForSend(TableAcquireDelayEvent()):
                    yield m
                yield (TableAcquireDelayEvent.createMatcher(),)
                module_list = list(self.table_modules)
                self._acquire_updated = False
                try:
                    for m in self.apiroutine.executeAll((callAPI(self.apiroutine, module, 'gettablerequest', {}) for module in module_list)):
                        yield m
                except QuitException:
                    raise
                except Exception as exc:
                    self._logger.exception('Acquiring table failed')
                    exception = exc
                else:
                    requests = [r[0] for r in self.apiroutine.retvalue]
                    vhosts = set(vh for _, vhs in requests if vhs is not None for vh in vhs)
                    vhost_result = {}
                    # Requests should be list of (name, (ancester, ancester, ...), pathname)
                    for vh in vhosts:
                        # graph: name -> (set of unvisited ancestors, set of descendants)
                        graph = {}
                        table_path = {}
                        try:
                            for r in requests:
                                if r[1] is None or vh in r[1]:
                                    for name, ancesters, pathname in r[0]:
                                        if name in table_path:
                                            if table_path[name] != pathname:
                                                raise ValueError("table conflict detected: %r can not be in two path: %r and %r" % (name, table_path[name], pathname))
                                        else:
                                            table_path[name] = pathname
                                        if name not in graph:
                                            graph[name] = (set(ancesters), set())
                                        else:
                                            graph[name][0].update(ancesters)
                                        for anc in ancesters:
                                            graph.setdefault(anc, (set(), set()))[1].add(name)
                        except ValueError as exc:
                            self._logger.error(str(exc))
                            exception = exc
                            break
                        else:
                            sequences = []
                            def dfs_sort(current):
                                # Emit a table once all of its ancestors were emitted
                                sequences.append(current)
                                for d in graph[current][1]:
                                    anc = graph[d][0]
                                    anc.remove(current)
                                    if not anc:
                                        dfs_sort(d)
                            # Deterministic start order: roots sorted by (path, name).
                            # Fixed: the key must use the lambda parameter ``x``, not
                            # the leaked loop variable ``name``
                            nopre_tables = sorted([k for k,v in graph.items() if not v[0]], key = lambda x: (table_path.get(x, ''), x))
                            for t in nopre_tables:
                                dfs_sort(t)
                            if len(sequences) < len(graph):
                                rest_tables = set(graph.keys()).difference(sequences)
                                self._logger.error("Circle detected in table acquiring, following tables are related: %r, vhost = %r", sorted(rest_tables), vh)
                                self._logger.error("Circle dependencies are: %s", ", ".join(repr(tuple(graph[t][0])) + "=>" + t for t in rest_tables))
                                exception = ValueError("Circle detected in table acquiring, following tables are related: %r, vhost = %r" % (sorted(rest_tables),vh))
                                break
                            elif len(sequences) > 255:
                                self._logger.error("Table limit exceeded: %d tables (only 255 allowed), vhost = %r", len(sequences), vh)
                                exception = ValueError("Table limit exceeded: %d tables (only 255 allowed), vhost = %r" % (len(sequences),vh))
                                break
                            else:
                                full_indices = list(zip(sequences, itertools.count()))
                                tables = dict((k,tuple(g)) for k,g in itertools.groupby(sorted(full_indices, key = lambda x: table_path.get(x[0], '')),
                                                           lambda x: table_path.get(x[0], '')))
                                vhost_result[vh] = (full_indices, tables)
        finally:
            self._acquiring = False
        if exception:
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(exception = exception)):
                yield m
        else:
            result = vhost_result
            if result != self._lastacquire:
                self._lastacquire = result
                self._reinitall()
            for m in self.apiroutine.waitForSend(TableAcquireUpdate(result = result)):
                yield m
    def load(self, container):
        """Module load hook: create the delayed table-acquire sub-queue and
        send an initial (empty) TableAcquireUpdate."""
        self.scheduler.queue.addSubQueue(1, TableAcquireDelayEvent.createMatcher(), 'ofpmanager_tableacquiredelay')
        for m in container.waitForSend(TableAcquireUpdate(result = None)):
            yield m
        for m in Module.load(self, container):
            yield m
    def unload(self, container, force=False):
        """Module unload hook: remove the table-acquire sub-queue."""
        for m in Module.unload(self, container, force=force):
            yield m
        for m in container.syscall(syscall_removequeue(self.scheduler.queue, 'ofpmanager_tableacquiredelay')):
            yield m
    def _reinitall(self):
        # Re-run flow table initialization on every managed connection
        for cl in self.managed_conns.values():
            for c in cl:
                self.apiroutine.subroutine(self._initialize_connection(c))
    def _manage_existing(self):
        """Pick up connections already accepted by the openflow server before
        this module started, then signal synchronization."""
        for m in callAPI(self.apiroutine, "openflowserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        for c in self.apiroutine.retvalue:
            if vb is None or c.protocol.vhost in vb:
                self._add_connection(c)
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m
    def _wait_for_sync(self):
        # Block callers until _manage_existing has populated the indices
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)
    def _manage_conns(self):
        """Main routine: track connection setup/teardown events and keep the
        datapath and endpoint indices up to date, broadcasting updates."""
        vb = self.vhostbind
        self.apiroutine.subroutine(self._manage_existing(), False)
        try:
            if vb is not None:
                conn_up = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_SETUP,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
                conn_down = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_DOWN,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
            else:
                conn_up = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_SETUP)
                conn_down = OpenflowConnectionStateEvent.createMatcher(state = OpenflowConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    e = self.apiroutine.event
                    remove = self._add_connection(e.connection)
                    self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update', add = [e.connection], remove = remove))
                else:
                    e = self.apiroutine.event
                    conns = self.managed_conns.get((e.createby.vhost, e.datapathid))
                    remove = []
                    if conns is not None:
                        try:
                            conns.remove(e.connection)
                        except ValueError:
                            pass
                        else:
                            remove.append(e.connection)
                        if not conns:
                            del self.managed_conns[(e.createby.vhost, e.datapathid)]
                        # Also delete from endpoint_conns
                        ep = _get_endpoint(e.connection)
                        econns = self.endpoint_conns.get((e.createby.vhost, ep))
                        if econns is not None:
                            try:
                                econns.remove(e.connection)
                            except ValueError:
                                pass
                            if not econns:
                                del self.endpoint_conns[(e.createby.vhost, ep)]
                    if remove:
                        self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'update', add = [], remove = remove))
        finally:
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))
    def getconnections(self, datapathid, vhost = ''):
        "Return all connections of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.get((vhost, datapathid), []))
    def getconnection(self, datapathid, auxiliaryid = 0, vhost = ''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getconnection(datapathid, auxiliaryid, vhost)
    def _getconnection(self, datapathid, auxiliaryid = 0, vhost = ''):
        # Synchronous lookup; returns None when no connection matches
        conns = self.managed_conns.get((vhost, datapathid))
        if conns is None:
            return None
        else:
            for c in conns:
                if c.openflow_auxiliaryid == auxiliaryid:
                    return c
            return None
    def waitconnection(self, datapathid, auxiliaryid = 0, timeout = 30, vhost = ''):
        "Wait for a datapath connection"
        for m in self._wait_for_sync():
            yield m
        c = self._getconnection(datapathid, auxiliaryid, vhost)
        if c is None:
            # Not connected yet: wait (with timeout) for the setup event
            for m in self.apiroutine.waitWithTimeout(timeout,
                            OpenflowConnectionStateEvent.createMatcher(datapathid, auxiliaryid,
                                    OpenflowConnectionStateEvent.CONNECTION_SETUP,
                                    _ismatch = lambda x: x.createby.vhost == vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath %016x is not connected' % datapathid)
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c
    def getdatapathids(self, vhost = ''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys() if k[0] == vhost]
    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())
    def getallconnections(self, vhost = ''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        # Fixed: chain.from_iterable flattens the per-datapath connection lists;
        # plain chain(iterable) would yield the lists themselves
        if vhost is None:
            self.apiroutine.retvalue = list(itertools.chain.from_iterable(self.managed_conns.values()))
        else:
            self.apiroutine.retvalue = list(itertools.chain.from_iterable(v for k,v in self.managed_conns.items() if k[0] == vhost))
    def getconnectionsbyendpoint(self, endpoint, vhost = ''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))
    def getconnectionsbyendpointname(self, name, vhost = '', timeout = 30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        # Resolve the name
        if not name:
            endpoint = ''
            # Fixed: method name typo (was getconnectionbyendpoint)
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                # Resolve is only allowed through asynchronous resolver
                #try:
                #    self.addrinfo = socket.getaddrinfo(self.hostname, self.port, socket.AF_UNSPEC, socket.SOCK_DGRAM if self.udp else socket.SOCK_STREAM, socket.IPPROTO_UDP if self.udp else socket.IPPROTO_TCP, socket.AI_ADDRCONFIG|socket.AI_NUMERICHOST)
                #except:
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                for r in resp:
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # Ignore port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    if self.apiroutine.retvalue is not None:
                        break
    def getendpoints(self, vhost = ''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]
    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())
    def lastacquiredtables(self, vhost = ""):
        "Get acquired table IDs; None before the first successful acquisition"
        # Guard: _lastacquire is None until the first acquire completes
        acquired = self._lastacquire
        return acquired.get(vhost) if acquired else None
    def acquiretable(self, modulename):
        "Start to acquire tables for a module on module loading."
        if modulename not in self.table_modules:
            self.table_modules.add(modulename)
            # Mark dirty; a running acquire routine re-reads this flag,
            # otherwise start a fresh one
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Keep this function a generator so it can run as a routine
        if False:
            yield
    def unacquiretable(self, modulename):
        "When module is unloaded, stop acquiring tables for this module."
        if modulename in self.table_modules:
            self.table_modules.remove(modulename)
            self._acquire_updated = True
            if not self._acquiring:
                self._acquiring = True
                self.apiroutine.subroutine(self._acquire_tables())
        self.apiroutine.retvalue = None
        # Keep this function a generator so it can run as a routine
        if False:
            yield
示例#5
0
class ObjectDB(Module):
    """
    Abstract transaction layer for KVDB
    """
    service = True
    # Priority for object update event
    _default_objectupdatepriority = 450
    # Enable debugging mode for updater: all updaters will be called for an extra time
    # to make sure it does not crash with multiple calls
    _default_debuggingupdater = False
    def __init__(self, server):
        """Initialize the transaction-layer state and register the public API."""
        Module.__init__(self, server)
        # key -> managed data object: the authoritative local cache
        self._managed_objs = {}
        # key -> set of watcher identifiers registered through watch/mwatch
        self._watches = {}
        # All keys currently listened on through the notifier
        self._watchedkeys = set()
        # Pending requests for the update routine; each entry carries
        # (keys, request-id, request-type, ...) — see _update for handling
        self._requests = []
        # Monotonic transaction counter
        self._transactno = 0
        # True when the KV storage could not be reached and we serve from cache
        self._stale = False
        # Keys that must be re-retrieved in the next update loop
        self._updatekeys = set()
        # key -> last seen version tuple; presumably (createtime, updateversion) — TODO confirm
        self._update_version = {}
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._update
        self.routines.append(self.apiroutine)
        self.createAPI(api(self.mget, self.apiroutine),
                       api(self.get, self.apiroutine),
                       api(self.mgetonce, self.apiroutine),
                       api(self.getonce, self.apiroutine),
                       api(self.mwatch, self.apiroutine),
                       api(self.watch, self.apiroutine),
                       api(self.munwatch, self.apiroutine),
                       api(self.unwatch, self.apiroutine),
                       api(self.unwatchall, self.apiroutine),
                       api(self.transact, self.apiroutine),
                       api(self.watchlist),
                       api(self.walk, self.apiroutine)
                       )
    def load(self, container):
        """Module load hook: set up the update sub-queue and the notifier,
        then perform the normal module load."""
        # Give data-object update events their own sub-queue with the
        # configured priority
        self.scheduler.queue.addSubQueue(
            self.objectupdatepriority,
            dataobj.DataObjectUpdateEvent.createMatcher(),
            'dataobjectupdate')
        # Create the cross-system update notifier before loading
        for ev in callAPI(container, 'updatenotifier', 'createnotifier'):
            yield ev
        self._notifier = container.retvalue
        for ev in Module.load(self, container):
            yield ev
        self.routines.append(self._notifier)
    def unload(self, container, force=False):
        """Module unload hook: remove the data-object update sub-queue,
        then perform the normal module unload."""
        for ev in container.syscall(syscall_removequeue(self.scheduler.queue, 'dataobjectupdate')):
            yield ev
        for ev in Module.unload(self, container, force=force):
            yield ev
    def _update(self):
        timestamp = '%012x' % (int(time() * 1000),) + '-'
        notification_matcher = self._notifier.notification_matcher(False)
        def copywithkey(obj, key):
            newobj = deepcopy(obj)
            if hasattr(newobj, 'setkey'):
                newobj.setkey(key)
            return newobj
        def getversion(obj):
            if obj is None:
                return (0, -1)
            else:
                return (getattr(obj, 'kvdb_createtime', 0), getattr(obj, 'kvdb_updateversion', 0))
        def isnewer(obj, version):
            if obj is None:
                return version[1] != -1
            else:
                return getversion(obj) > version
        request_matcher = RetrieveRequestSend.createMatcher()
        def onupdate(event, matcher):
            update_keys = self._watchedkeys.intersection([_str(k) for k in event.keys])
            self._updatekeys.update(update_keys)
            if event.extrainfo:
                for k,v in zip(event.keys, event.extrainfo):
                    k = _str(k)
                    if k in update_keys:
                        v = tuple(v)
                        oldv = self._update_version.get(k, (0, -1))
                        if oldv < v:
                            self._update_version[k] = v
            else:
                for k in event.keys:
                    try:
                        del self._update_version[_str(k)]
                    except KeyError:
                        pass
        def updateinner():
            processing_requests = []
            # New managed keys
            retrieve_list = set()
            orig_retrieve_list = set()
            retrieveonce_list = set()
            orig_retrieveonce_list = set()
            # Retrieved values are stored in update_result before merging into current storage
            update_result = {}
            # key => [(walker_func, original_keys, rid), ...]
            walkers = {}
            self._loopCount = 0
            # A request-id -> retrieve set dictionary to store the saved keys
            savelist = {}
            def updateloop():
                while (retrieve_list or self._updatekeys or self._requests):
                    watch_keys = set()
                    # Updated keys
                    update_list = set()
                    if self._loopCount >= 10 and not retrieve_list:
                        if not self._updatekeys:
                            break
                        elif self._loopCount >= 100:
                            # Too many updates, we must stop to respond
                            self._logger.warning("There are still database updates after 100 loops of mget, respond with potential inconsistent values")
                            break
                    if self._updatekeys:
                        update_list.update(self._updatekeys)
                        self._updatekeys.clear()
                    if self._requests:
                        # Processing requests
                        for r in self._requests:
                            if r[2] == 'unwatch':
                                try:
                                    for k in r[0]:
                                        s = self._watches.get(k)
                                        if s:
                                            s.discard(r[3])
                                            if not s:
                                                del self._watches[k]
                                    # Do not need to wait
                                except Exception as exc:
                                    for m in self.apiroutine.waitForSend(RetrieveReply(r[1], exception = exc)):
                                        yield m                                    
                                else:
                                    for m in self.apiroutine.waitForSend(RetrieveReply(r[1], result = None)):
                                        yield m
                            elif r[2] == 'watch':
                                retrieve_list.update(r[0])
                                orig_retrieve_list.update(r[0])
                                for k in r[0]:
                                    self._watches.setdefault(k, set()).add(r[3])
                                processing_requests.append(r)
                            elif r[2] == 'get':
                                retrieve_list.update(r[0])
                                orig_retrieve_list.update(r[0])
                                processing_requests.append(r)
                            elif r[2] == 'walk':
                                retrieve_list.update(r[0])
                                processing_requests.append(r)
                                for k,v in r[3].items():
                                    walkers.setdefault(k, []).append((v, (r[0], r[1])))
                            else:
                                retrieveonce_list.update(r[0])
                                orig_retrieveonce_list.update(r[0])
                                processing_requests.append(r)
                        del self._requests[:]
                    if retrieve_list:
                        watch_keys.update(retrieve_list)
                    # Add watch_keys to notification
                    watch_keys.difference_update(self._watchedkeys)
                    if watch_keys:
                        for k in watch_keys:
                            if k in update_result:
                                self._update_version[k] = getversion(update_result[k])
                        for m in self._notifier.add_listen(*tuple(watch_keys.difference(self._watchedkeys))):
                            yield m
                        self._watchedkeys.update(watch_keys)
                    get_list_set = update_list.union(retrieve_list.union(retrieveonce_list).difference(self._managed_objs.keys()).difference(update_result.keys()))
                    get_list = list(get_list_set)
                    if get_list:
                        try:
                            for m in callAPI(self.apiroutine, 'kvstorage', 'mget', {'keys': get_list}):
                                yield m
                        except QuitException:
                            raise
                        except Exception:
                            # Serve with cache
                            if not self._stale:
                                self._logger.warning('KVStorage retrieve failed, serve with cache', exc_info = True)
                            self._stale = True
                            # Discard all retrieved results
                            update_result.clear()
                            # Retry update later
                            self._updatekeys.update(update_list)
                            #break
                            changed_set = set()
                        else:
                            result = self.apiroutine.retvalue
                            self._stale = False
                            for k,v in zip(get_list, result):
                                if v is not None and hasattr(v, 'setkey'):
                                    v.setkey(k)
                                if k in self._watchedkeys and k not in self._update_version:
                                    self._update_version[k] = getversion(v)
                            changed_set = set(k for k,v in zip(get_list, result) if k not in update_result or getversion(v) != getversion(update_result[k]))
                            update_result.update(zip(get_list, result))
                    else:
                        changed_set = set()
                    # All keys which should be retrieved in next loop
                    new_retrieve_list = set()
                    # Keys which should be retrieved in next loop for a single walk
                    new_retrieve_keys = set()
                    # Keys that are used in current walk will be retrieved again in next loop
                    used_keys = set()
                    # We separate the original data and new retrieved data space, and do not allow
                    # cross usage, to prevent discontinue results 
                    def walk_original(key):
                        if hasattr(key, 'getkey'):
                            key = key.getkey()
                        key = _str(key)
                        if key not in self._watchedkeys:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        elif self._stale:
                            if key not in self._managed_objs:
                                new_retrieve_keys.add(key)
                            else:
                                used_keys.add(key)
                            return self._managed_objs.get(key)
                        elif key in changed_set:
                            # We are retrieving from the old result, do not allow to use new data
                            used_keys.add(key)
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        elif key in update_result:
                            used_keys.add(key)
                            return update_result[key]
                        elif key in self._managed_objs:
                            used_keys.add(key)
                            return self._managed_objs[key]
                        else:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                    def walk_new(key):
                        if hasattr(key, 'getkey'):
                            key = key.getkey()
                        key = _str(key)
                        if key not in self._watchedkeys:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        elif key in get_list_set:
                            # We are retrieving from the new data
                            used_keys.add(key)
                            return update_result[key]
                        elif key in self._managed_objs or key in update_result:
                            # Do not allow the old data
                            used_keys.add(key)
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                        else:
                            # This key is not retrieved, raise a KeyError, and record this key
                            new_retrieve_keys.add(key)
                            raise KeyError('Not retrieved')
                    def create_walker(orig_key):
                        if self._stale:
                            return walk_original
                        elif orig_key in changed_set:
                            return walk_new
                        else:
                            return walk_original
                    walker_set = set()
                    def default_walker(key, obj, walk):
                        if key in walker_set:
                            return
                        else:
                            walker_set.add(key)
                        if hasattr(obj, 'kvdb_retrievelist'):
                            rl = obj.kvdb_retrievelist()
                            for k in rl:
                                try:
                                    newobj = walk(k)
                                except KeyError:
                                    pass
                                else:
                                    if newobj is not None:
                                        default_walker(k, newobj, walk)
                    for k in orig_retrieve_list:
                        v = update_result.get(k)
                        if v is not None:
                            new_retrieve_keys.clear()
                            used_keys.clear()
                            default_walker(k, v, create_walker(k))
                            if new_retrieve_keys:
                                new_retrieve_list.update(new_retrieve_keys)
                                self._updatekeys.update(used_keys)
                                self._updatekeys.add(k)
                    savelist.clear()
                    for k,ws in walkers.items():
                        # k: the walker key
                        # ws: list of [walker_func, (request_original_keys, rid)]
                        # Retry every walker, starts with k, with the value of v
                        if k in update_result:
                            # The value is newly retrieved
                            v = update_result.get(k)
                        else:
                            # Use the stored value
                            v = self._managed_objs.get(k)
                        if ws:
                            for w,r in list(ws):
                                # w: walker_func
                                # r: (request_original_keys, rid)
                                # Custom walker
                                def save(key):
                                    if hasattr(key, 'getkey'):
                                        key = key.getkey()
                                    key = _str(key)
                                    if key != k and key not in used_keys:
                                        raise ValueError('Cannot save a key without walk')
                                    savelist.setdefault(r[1], set()).add(key)
                                try:
                                    new_retrieve_keys.clear()
                                    used_keys.clear()
                                    w(k, v, create_walker(k), save)
                                except Exception as exc:
                                    # if one walker failed, the whole request is failed, remove all walkers
                                    self._logger.warning("A walker raises an exception which rolls back the whole walk process %r. "
                                                         "walker = %r, start key = %r, new_retrieve_keys = %r, used_keys = %r",
                                                         w, k, r[1], new_retrieve_keys, used_keys, exc_info=True)
                                    for orig_k in r[0]:
                                        if orig_k in walkers:
                                            walkers[orig_k][:] = [(w0, r0) for w0,r0 in walkers[orig_k] if r0[1] != r[1]]
                                    processing_requests[:] = [r0 for r0 in processing_requests if r0[1] != r[1]]
                                    savelist.pop(r[1])
                                    for m in self.apiroutine.waitForSend(RetrieveReply(r[1], exception = exc)):
                                        yield m
                                else:
                                    if new_retrieve_keys:
                                        new_retrieve_list.update(new_retrieve_keys)
                                        self._updatekeys.update(used_keys)
                                        self._updatekeys.add(k)
                    for save in savelist.values():
                        for k in save:
                            v = update_result.get(k)
                            if v is not None:
                                # If we retrieved a new value, we should also retrieved the references
                                # from this value
                                new_retrieve_keys.clear()
                                used_keys.clear()
                                default_walker(k, v, create_walker(k))
                                if new_retrieve_keys:
                                    new_retrieve_list.update(new_retrieve_keys)
                                    self._updatekeys.update(used_keys)
                                    self._updatekeys.add(k)                            
                    retrieve_list.clear()
                    retrieveonce_list.clear()
                    retrieve_list.update(new_retrieve_list)
                    self._loopCount += 1
                    if self._stale:
                        watch_keys = set(retrieve_list)
                        watch_keys.difference_update(self._watchedkeys)
                        if watch_keys:
                            for m in self._notifier.add_listen(*tuple(watch_keys)):
                                yield m
                            self._watchedkeys.update(watch_keys)
                        break
            while True:
                for m in self.apiroutine.withCallback(updateloop(), onupdate, notification_matcher):
                    yield m
                if self._loopCount >= 100 or self._stale:
                    break
                # If some updated result is newer than the notification version, we should wait for the notification
                should_wait = False
                for k,v in update_result.items():
                    if k in self._watchedkeys:
                        oldv = self._update_version.get(k)
                        if oldv is not None and isnewer(v, oldv):
                            should_wait = True
                            break
                if should_wait:
                    for m in self.apiroutine.waitWithTimeout(0.2, notification_matcher):
                        yield m
                    if self.apiroutine.timeout:
                        break
                    else:
                        onupdate(self.apiroutine.event, self.apiroutine.matcher)
                else:
                    break
            # Update result
            send_events = []
            self._transactno += 1
            transactid = '%s%016x' % (timestamp, self._transactno)
            update_objs = []
            for k,v in update_result.items():
                if k in self._watchedkeys:
                    if v is None:
                        oldv = self._managed_objs.get(k)
                        if oldv is not None:
                            if hasattr(oldv, 'kvdb_detach'):
                                oldv.kvdb_detach()
                                update_objs.append((k, oldv, dataobj.DataObjectUpdateEvent.DELETED))
                            else:
                                update_objs.append((k, None, dataobj.DataObjectUpdateEvent.DELETED))
                            del self._managed_objs[k]
                    else:
                        oldv = self._managed_objs.get(k)
                        if oldv is not None:
                            if oldv != v:
                                if oldv and hasattr(oldv, 'kvdb_update'):
                                    oldv.kvdb_update(v)
                                    update_objs.append((k, oldv, dataobj.DataObjectUpdateEvent.UPDATED))
                                else:
                                    if hasattr(oldv, 'kvdb_detach'):
                                        oldv.kvdb_detach()
                                    self._managed_objs[k] = v
                                    update_objs.append((k, v, dataobj.DataObjectUpdateEvent.UPDATED))
                        else:
                            self._managed_objs[k] = v
                            update_objs.append((k, v, dataobj.DataObjectUpdateEvent.UPDATED))
            for k in update_result.keys():
                v = self._managed_objs.get(k)
                if v is not None and hasattr(v, 'kvdb_retrievefinished'):
                    v.kvdb_retrievefinished(self._managed_objs)
            allkeys = tuple(k for k,_,_ in update_objs)
            send_events.extend((dataobj.DataObjectUpdateEvent(k, transactid, t, object = v, allkeys = allkeys) for k,v,t in update_objs))
            # Process requests
            for r in processing_requests:
                if r[2] == 'get':
                    objs = [self._managed_objs.get(k) for k in r[0]]
                    for k,v in zip(r[0], objs):
                        if v is not None:
                            self._watches.setdefault(k, set()).add(r[3])
                    result = [o.create_reference() if o is not None and hasattr(o, 'create_reference') else o
                              for o in objs]
                elif r[2] == 'watch':
                    result = [(v.create_reference() if hasattr(v, 'create_reference') else v)
                              if v is not None else dataobj.ReferenceObject(k)
                              for k,v in ((k,self._managed_objs.get(k)) for k in r[0])]
                elif r[2] == 'walk':
                    saved_keys = list(savelist.get(r[1], []))
                    for k in saved_keys:
                        self._watches.setdefault(k, set()).add(r[4])
                    objs = [self._managed_objs.get(k) for k in saved_keys]
                    result = (saved_keys,
                              [o.create_reference() if hasattr(o, 'create_reference') else o
                               if o is not None else dataobj.ReferenceObject(k)
                               for k,o in zip(saved_keys, objs)])
                else:
                    result = [copywithkey(update_result.get(k, self._managed_objs.get(k)), k) for k in r[0]]
                send_events.append(RetrieveReply(r[1], result = result, stale = self._stale))
            # Use DFS to remove unwatched objects
            mark_set = set()
            def dfs(k):
                if k in mark_set:
                    return
                mark_set.add(k)
                v = self._managed_objs.get(k)
                if v is not None and hasattr(v, 'kvdb_internalref'):
                    for k2 in v.kvdb_internalref():
                        dfs(k2)
            for k in self._watches.keys():
                dfs(k)
            def output_result():
                remove_keys = self._watchedkeys.difference(mark_set)
                if remove_keys:
                    self._watchedkeys.difference_update(remove_keys)
                    for m in self._notifier.remove_listen(*tuple(remove_keys)):
                        yield m
                    for k in remove_keys:
                        if k in self._managed_objs:
                            del self._managed_objs[k]
                        if k in self._update_version:
                            del self._update_version[k]
                for e in send_events:
                    for m in self.apiroutine.waitForSend(e):
                        yield m
            for m in self.apiroutine.withCallback(output_result(), onupdate):
                yield m
        while True:
            if not self._updatekeys and not self._requests:
                yield (notification_matcher, request_matcher)
                if self.apiroutine.matcher is notification_matcher:
                    onupdate(self.apiroutine.event, self.apiroutine.matcher)
            for m in updateinner():
                yield m
    def mget(self, keys, requestid, nostale = False):
        """
        Retrieve multiple objects and put them under management.
        
        Queues a 'get' request for the processing routine and waits for its
        reply; sets ``self.apiroutine.retvalue`` to a list of references
        to the retrieved objects.
        
        :param keys: iterable of database keys
        :param requestid: identifier later passed to ``munwatch`` to release the keys
        :param nostale: if True, raise StaleResultException on a possibly-stale result
        """
        normalized_keys = tuple(_str2(k) for k in keys)
        # Only wake the processing routine when the queue was empty;
        # one notification drains the whole queue
        first_request = not self._requests
        rid = object()
        self._requests.append((normalized_keys, rid, 'get', requestid))
        if first_request:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        reply = self.apiroutine.event
        if hasattr(reply, 'exception'):
            raise reply.exception
        if nostale and reply.stale:
            raise StaleResultException(reply.result)
        self.apiroutine.retvalue = reply.result
    def get(self, key, requestid, nostale = False):
        """
        Get an object from specified key, and manage the object.
        Return a reference to the object or None if not exists.
        """
        # Delegate to mget with a one-element key list and unwrap the result
        for m in self.mget((key,), requestid, nostale):
            yield m
        results = self.apiroutine.retvalue
        self.apiroutine.retvalue = results[0]
    def mgetonce(self, keys, nostale = False):
        """
        Get copies of multiple objects without managing them.
        Referenced objects are not retrieved.
        
        Sets ``self.apiroutine.retvalue`` to a list of copies.
        """
        requested_keys = tuple(_str2(k) for k in keys)
        # A notification is needed only when the queue was previously empty
        first_request = not self._requests
        rid = object()
        self._requests.append((requested_keys, rid, 'getonce'))
        if first_request:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        reply = self.apiroutine.event
        if hasattr(reply, 'exception'):
            raise reply.exception
        if nostale and reply.stale:
            raise StaleResultException(reply.result)
        self.apiroutine.retvalue = reply.result
    def getonce(self, key, nostale = False):
        """
        Get an object without managing it. Return a copy of the object,
        or None if it does not exist. Referenced objects are not retrieved.
        """
        # Single-key convenience wrapper around mgetonce
        for m in self.mgetonce((key,), nostale):
            yield m
        results = self.apiroutine.retvalue
        self.apiroutine.retvalue = results[0]
    def watch(self, key, requestid, nostale = False):
        """
        Try to find an object and return a reference. Use ``reference.isdeleted()`` to test
        whether the object exists.
        Use ``reference.wait(container)`` to wait for the object to be existed.
        """
        # Single-key convenience wrapper around mwatch
        for m in self.mwatch((key,), requestid, nostale):
            yield m
        references = self.apiroutine.retvalue
        self.apiroutine.retvalue = references[0]
    def mwatch(self, keys, requestid, nostale = False):
        """
        Try to return all the references, see ``watch()``
        
        :param keys: iterable of database keys
        :param requestid: identifier later passed to ``munwatch`` to cancel the watches
        :param nostale: if True, raise StaleResultException on a possibly-stale result
        
        Sets ``self.apiroutine.retvalue`` to the list of references.
        """
        keys = tuple(_str2(k) for k in keys)
        # Only wake the processing routine when the queue was empty
        notify = not self._requests
        rid = object()
        # Fix: the request must be appended as a single tuple. list.append()
        # takes exactly one argument, so the previous form
        # self._requests.append(keys, rid, 'watch', requestid)
        # raised TypeError at runtime; every sibling method appends a tuple.
        self._requests.append((keys, rid, 'watch', requestid))
        if notify:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        if hasattr(self.apiroutine.event, 'exception'):
            raise self.apiroutine.event.exception
        if nostale and self.apiroutine.event.stale:
            raise StaleResultException(self.apiroutine.event.result)
        self.apiroutine.retvalue = self.apiroutine.event.result
    def unwatch(self, key, requestid):
        "Cancel management of a single key"
        # Delegate to the multi-key variant with a one-element tuple
        for m in self.munwatch((key,), requestid):
            yield m
        self.apiroutine.retvalue = None
    def unwatchall(self, requestid):
        "Cancel management for all keys that are managed by requestid"
        # Collect every watched key this request id participates in,
        # then release them all at once
        watched = [key
                   for key, requestids in self._watches.items()
                   if requestid in requestids]
        for m in self.munwatch(watched, requestid):
            yield m
    def munwatch(self, keys, requestid):
        """
        Cancel management of multiple keys.
        
        Sets ``self.apiroutine.retvalue`` to None on success.
        """
        normalized_keys = tuple(_str2(k) for k in keys)
        # Only wake the processing routine when the queue was empty
        first_request = not self._requests
        rid = object()
        self._requests.append((normalized_keys, rid, 'unwatch', requestid))
        if first_request:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        reply = self.apiroutine.event
        if hasattr(reply, 'exception'):
            raise reply.exception
        self.apiroutine.retvalue = None
    def transact(self, keys, updater, withtime = False):
        """
        Try to update keys in a transact, with an ``updater(keys, values)``,
        which returns ``(updated_keys, updated_values)``.
        
        The updater may be called more than once. If ``withtime = True``,
        the updater should take three parameters:
        ``(keys, values, timestamp)`` with timestamp as the server time.
        
        The transact also automatically maintains index objects (unique-key
        and multi-key references and their key sets) and auto-removed keys:
        when the first attempt discovers that more keys must participate in
        the transaction, it raises an internal ``_NeedMoreKeysException``
        and retries with the enlarged key set until the update succeeds.
        """
        keys = tuple(_str2(k) for k in keys)
        # updated_ref[0] = final updated keys, updated_ref[1] = their new
        # (createtime, updateversion) pairs; filled by object_updater so the
        # results survive after the storage call returns
        updated_ref = [None, None]
        # Index keys / key-set keys joined into the transact; discovered
        # incrementally through the _NeedMoreKeysException retry protocol
        extra_keys = []
        extra_key_set = []
        # Keys removed automatically because the objects owning them are deleted
        auto_remove_keys = set()
        orig_len = len(keys)
        def updater_with_key(keys, values, timestamp):
            # Wraps the user updater: calls it on the original keys only, then
            # maintains the unique-key / multi-key index objects consistently.
            # Automatically manage extra keys
            remove_uniquekeys = []
            remove_multikeys = []
            update_uniquekeys = []
            update_multikeys = []
            # values[:keystart] are the "real" objects (original keys plus
            # auto-removed keys); the rest are index objects appended after them
            keystart = orig_len + len(auto_remove_keys)
            for v in values[:keystart]:
                if v is not None:
                    if hasattr(v, 'kvdb_uniquekeys'):
                        remove_uniquekeys.extend((k,v.create_weakreference()) for k in v.kvdb_uniquekeys())
                    if hasattr(v, 'kvdb_multikeys'):
                        remove_multikeys.extend((k,v.create_weakreference()) for k in v.kvdb_multikeys())
            if self.debuggingupdater:
                # Updater may be called more than once, ensure that this updater does not crash
                # on multiple calls
                kc = keys[:orig_len]
                vc = [v.clone_instance() if v is not None and hasattr(v, 'clone_instance') else deepcopy(v) for v in values[:orig_len]]
                if withtime:
                    updated_keys, updated_values = updater(kc, vc, timestamp)
                else:
                    updated_keys, updated_values = updater(kc, vc)
            # The real call: only the original keys/values are passed to the user updater
            if withtime:
                updated_keys, updated_values = updater(keys[:orig_len], values[:orig_len], timestamp)
            else:
                updated_keys, updated_values = updater(keys[:orig_len], values[:orig_len])
            for v in updated_values:
                if v is not None:
                    if hasattr(v, 'kvdb_uniquekeys'):
                        update_uniquekeys.extend((k,v.create_weakreference()) for k in v.kvdb_uniquekeys())
                    if hasattr(v, 'kvdb_multikeys'):
                        update_multikeys.extend((k,v.create_weakreference()) for k in v.kvdb_multikeys())
            # Slice out the extra index objects appended after the real keys
            extrakeysdict = dict(zip(keys[keystart:keystart + len(extra_keys)], values[keystart:keystart + len(extra_keys)]))
            extrakeysetdict = dict(zip(keys[keystart + len(extra_keys):keystart + len(extra_keys) + len(extra_key_set)],
                                       values[keystart + len(extra_keys):keystart + len(extra_keys) + len(extra_key_set)]))
            # tempdict holds index objects removed in this pass; an update may
            # restore them instead of creating new ones
            tempdict = {}
            old_values = dict(zip(keys, values))
            updated_keyset = set(updated_keys)
            try:
                append_remove = set()
                autoremove_keys = set()
                # Use DFS to find auto remove keys
                def dfs(k):
                    if k in autoremove_keys:
                        return
                    autoremove_keys.add(k)
                    if k not in old_values:
                        # Key not yet part of this transact: must retry with more keys
                        append_remove.add(k)
                    else:
                        oldv = old_values[k]
                        if oldv is not None and hasattr(oldv, 'kvdb_autoremove'):
                            for k2 in oldv.kvdb_autoremove():
                                dfs(k2)
                for k,v in zip(updated_keys, updated_values):
                    if v is None:
                        dfs(k)
                if append_remove:
                    raise _NeedMoreKeysException()
                # Detach unique-key index entries of updated/removed objects
                for k,v in remove_uniquekeys:
                    if v.getkey() not in updated_keyset and v.getkey() not in auto_remove_keys:
                        # This key is not updated, keep the indices untouched
                        continue
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    elif extrakeysdict[k] is not None and extrakeysdict[k].ref.getkey() == v.getkey():
                        # If the unique key does not reference to the correct object
                        # there may be an error, but we ignore this.
                        # Save in a temporary dictionary. We may restore it later.
                        tempdict[k] = extrakeysdict[k]
                        extrakeysdict[k] = None
                        setkey = UniqueKeyReference.get_keyset_from_key(k)
                        if setkey not in extrakeysetdict:
                            raise _NeedMoreKeysException()
                        else:
                            ks = extrakeysetdict[setkey]
                            if ks is None:
                                ks = UniqueKeySet.create_from_key(setkey)
                                extrakeysetdict[setkey] = ks
                            ks.set.dataset().discard(WeakReferenceObject(k))
                # Detach multi-key index entries of updated/removed objects
                for k,v in remove_multikeys:
                    if v.getkey() not in updated_keyset and v.getkey() not in auto_remove_keys:
                        # This key is not updated, keep the indices untouched
                        continue
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    else:
                        mk = extrakeysdict[k]
                        if mk is not None:
                            mk.set.dataset().discard(v)
                            if not mk.set.dataset():
                                # The multi-key reference became empty: remove it
                                # and drop it from its key set as well
                                tempdict[k] = extrakeysdict[k]
                                extrakeysdict[k] = None
                                setkey = MultiKeyReference.get_keyset_from_key(k)
                                if setkey not in extrakeysetdict:
                                    raise _NeedMoreKeysException()
                                else:
                                    ks = extrakeysetdict[setkey]
                                    if ks is None:
                                        ks = MultiKeySet.create_from_key(setkey)
                                        extrakeysetdict[setkey] = ks
                                    ks.set.dataset().discard(WeakReferenceObject(k))
                # Attach unique-key index entries for the updated values
                for k,v in update_uniquekeys:
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    elif extrakeysdict[k] is not None and extrakeysdict[k].ref.getkey() != v.getkey():
                        raise AlreadyExistsException('Unique key conflict for %r and %r, with key %r' % \
                                                     (extrakeysdict[k].ref.getkey(), v.getkey(), k))
                    elif extrakeysdict[k] is None:
                        lv = tempdict.get(k, None)
                        if lv is not None and lv.ref.getkey() == v.getkey():
                            # Restore this value
                            nv = lv
                        else:
                            nv = UniqueKeyReference.create_from_key(k)
                            nv.ref = ReferenceObject(v.getkey())
                        extrakeysdict[k] = nv
                        setkey = UniqueKeyReference.get_keyset_from_key(k)
                        if setkey not in extrakeysetdict:
                            raise _NeedMoreKeysException()
                        else:
                            ks = extrakeysetdict[setkey]
                            if ks is None:
                                ks = UniqueKeySet.create_from_key(setkey)
                                extrakeysetdict[setkey] = ks
                            ks.set.dataset().add(nv.create_weakreference())
                # Attach multi-key index entries for the updated values
                for k,v in update_multikeys:
                    if k not in extrakeysdict:
                        raise _NeedMoreKeysException()
                    else:
                        mk = extrakeysdict[k]
                        if mk is None:
                            mk = tempdict.get(k, None)
                            if mk is None:
                                mk = MultiKeyReference.create_from_key(k)
                                mk.set = DataObjectSet()
                            setkey = MultiKeyReference.get_keyset_from_key(k)
                            if setkey not in extrakeysetdict:
                                raise _NeedMoreKeysException()
                            else:
                                ks = extrakeysetdict[setkey]
                                if ks is None:
                                    ks = MultiKeySet.create_from_key(setkey)
                                    extrakeysetdict[setkey] = ks
                                ks.set.dataset().add(mk.create_weakreference())
                        mk.set.dataset().add(v)
                        extrakeysdict[k] = mk
            except _NeedMoreKeysException:
                # Prepare the keys
                # Recompute the full sets of index keys and key-set keys that must
                # join the next attempt, then re-raise so the outer loop retries
                extra_keys[:] = list(set(itertools.chain((k for k,v in remove_uniquekeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                         (k for k,v in remove_multikeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                         (k for k,_ in update_uniquekeys),
                                                         (k for k,_ in update_multikeys))))
                extra_key_set[:] = list(set(itertools.chain((UniqueKeyReference.get_keyset_from_key(k) for k,v in remove_uniquekeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                         (MultiKeyReference.get_keyset_from_key(k) for k,v in remove_multikeys if v.getkey() in updated_keyset or v.getkey() in autoremove_keys),
                                                         (UniqueKeyReference.get_keyset_from_key(k) for k,_ in update_uniquekeys),
                                                         (MultiKeyReference.get_keyset_from_key(k) for k,_ in update_multikeys))))
                auto_remove_keys.clear()
                auto_remove_keys.update(autoremove_keys.difference(keys[:orig_len])
                                                          .difference(extra_keys)
                                                          .difference(extra_key_set))
                raise
            else:
                # Success: merge user updates, index updates and auto-removed keys
                # (the latter are written as None, i.e. deleted) into one result
                extrakeys_list = list(extrakeysdict.items())
                extrakeyset_list = list(extrakeysetdict.items())
                autoremove_list = list(autoremove_keys.difference(updated_keys)
                                                      .difference(extrakeysdict.keys())
                                                      .difference(extrakeysetdict.keys()))
                return (tuple(itertools.chain(updated_keys,
                                              (k for k,_ in extrakeys_list),
                                              (k for k,_ in extrakeyset_list),
                                              autoremove_list)),
                        tuple(itertools.chain(updated_values,
                                               (v for _,v in extrakeys_list),
                                               (v for _,v in extrakeyset_list),
                                               [None] * len(autoremove_list))))

        def object_updater(keys, values, timestamp):
            # Second wrapper layer passed to the storage service: maintains
            # kvdb_createtime / kvdb_updateversion on every updated object and
            # records the results in updated_ref for post-transaction notification
            old_version = {}
            for k, v in zip(keys, values):
                if v is not None and hasattr(v, 'setkey'):
                    v.setkey(k)
                if v is not None and hasattr(v, 'kvdb_createtime'):
                    old_version[k] = (getattr(v, 'kvdb_createtime'), getattr(v, 'kvdb_updateversion', 1))
            updated_keys, updated_values = updater_with_key(keys, values, timestamp)
            updated_ref[0] = tuple(updated_keys)
            new_version = []
            for k,v in zip(updated_keys, updated_values):
                if v is None:
                    # Deleted key: version -1 marks the deletion at this timestamp
                    new_version.append((timestamp, -1))
                elif k in old_version:
                    # Existing object: keep createtime, bump updateversion
                    ov = old_version[k]
                    setattr(v, 'kvdb_createtime', ov[0])
                    setattr(v, 'kvdb_updateversion', ov[1] + 1)
                    new_version.append((ov[0], ov[1] + 1))
                else:
                    # Newly created object
                    setattr(v, 'kvdb_createtime', timestamp)
                    setattr(v, 'kvdb_updateversion', 1)
                    new_version.append((timestamp, 1))
            updated_ref[1] = new_version
            return (updated_keys, updated_values)
        # Retry the storage transaction until the participating key set is
        # complete; _NeedMoreKeysException enlarges auto_remove_keys /
        # extra_keys / extra_key_set for the next attempt
        while True:
            try:
                for m in callAPI(self.apiroutine, 'kvstorage', 'updateallwithtime',
                                 {'keys': keys + tuple(auto_remove_keys) + \
                                         tuple(extra_keys) + tuple(extra_key_set),
                                         'updater': object_updater}):
                    yield m
            except _NeedMoreKeysException:
                pass
            else:
                break
        # Short cut update notification
        # Locally mark watched keys as updated so the update loop refreshes them
        # without waiting for the external notifier round-trip
        update_keys = self._watchedkeys.intersection(updated_ref[0])
        self._updatekeys.update(update_keys)
        for k,v in zip(updated_ref[0], updated_ref[1]):
            k = _str(k)
            if k in update_keys:
                v = tuple(v)
                oldv = self._update_version.get(k, (0, -1))
                if oldv < v:
                    # Only advance the version, never roll it back
                    self._update_version[k] = v
        for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
            yield m
        # Publish the change to other nodes through the notifier
        for m in self._notifier.publish(updated_ref[0], updated_ref[1]):
            yield m
    def watchlist(self, requestid = None):
        """
        Return a dictionary whose keys are database keys, and values are lists of request ids.
        Optionally filtered by request id
        """
        if requestid is None:
            # No filter: include every watched key
            return {k: list(v) for k, v in self._watches.items()}
        return {k: list(v) for k, v in self._watches.items() if requestid in v}
    def walk(self, keys, walkerdict, requestid, nostale = False):
        """
        Recursively retrieve keys with customized functions.
        walkerdict is a dictionary ``key->walker(key, obj, walk, save)``.
        """
        start_keys = tuple(_str2(k) for k in keys)
        # Only wake the processing routine when the queue was empty
        first_request = not self._requests
        rid = object()
        # A private copy of walkerdict protects against caller-side mutation
        self._requests.append((start_keys, rid, 'walk', dict(walkerdict), requestid))
        if first_request:
            for m in self.apiroutine.waitForSend(RetrieveRequestSend()):
                yield m
        yield (RetrieveReply.createMatcher(rid),)
        reply = self.apiroutine.event
        if hasattr(reply, 'exception'):
            raise reply.exception
        if nostale and reply.stale:
            raise StaleResultException(reply.result)
        self.apiroutine.retvalue = reply.result
        
Example #6
class OVSDBManager(Module):
    '''
    Manage OVSDB (JSON-RPC) connections: track every bridge defined on each
    connected OVSDB endpoint and map (vhost, datapath-id) pairs to live
    connections.
    '''
    service = True
    # Restrict management to these vhosts; None means manage every vhost
    _default_vhostbind = None
    # Restrict management to these bridge names; None means manage every bridge
    _default_bridgenames = None
    def __init__(self, server):
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_conns
        self.routines.append(self.apiroutine)
        # (vhost, datapath-id) -> connection
        self.managed_conns = {}
        # (vhost, system-id) -> connection
        self.managed_systemids = {}
        # connection -> [(vhost, datapath-id, bridge-name, bridge-uuid), ...]
        self.managed_bridges = {}
        self.managed_routines = []
        # (vhost, endpoint) -> [connection, ...]
        self.endpoint_conns = {}
        self.createAPI(api(self.getconnection, self.apiroutine),
                       api(self.waitconnection, self.apiroutine),
                       api(self.getdatapathids, self.apiroutine),
                       api(self.getalldatapathids, self.apiroutine),
                       api(self.getallconnections, self.apiroutine),
                       api(self.getbridges, self.apiroutine),
                       api(self.getbridge, self.apiroutine),
                       api(self.getbridgebyuuid, self.apiroutine),
                       api(self.waitbridge, self.apiroutine),
                       api(self.waitbridgebyuuid, self.apiroutine),
                       api(self.getsystemids, self.apiroutine),
                       api(self.getallsystemids, self.apiroutine),
                       api(self.getconnectionbysystemid, self.apiroutine),
                       api(self.waitconnectionbysystemid, self.apiroutine),
                       api(self.getconnectionsbyendpoint, self.apiroutine),
                       api(self.getconnectionsbyendpointname, self.apiroutine),
                       api(self.getendpoints, self.apiroutine),
                       api(self.getallendpoints, self.apiroutine),
                       api(self.getallbridges, self.apiroutine),
                       api(self.getbridgeinfo, self.apiroutine),
                       api(self.waitbridgeinfo, self.apiroutine)
                       )
        self._synchronized = False
    def _update_bridge(self, connection, protocol, bridge_uuid, vhost):
        """
        Wait until the bridge identified by *bridge_uuid* has a datapath-id,
        then register it in the lookup tables and broadcast an
        OVSDBBridgeSetup UP event.
        """
        try:
            # First wait for the datapath_id column to be populated, then
            # read it together with the bridge name in the same transaction
            method, params = ovsdb.transact('Open_vSwitch',
                                            ovsdb.wait('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                        ["datapath_id"], [{"datapath_id": ovsdb.oset()}], False, 5000),
                                            ovsdb.select('Bridge', [["_uuid", "==", ovsdb.uuid(bridge_uuid)]],
                                                                         ["datapath_id","name"]))
            for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                yield m
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring datapath-id: ' + repr(r['error']))
            if r['rows']:
                r0 = r['rows'][0]
                name = r0['name']
                dpid = int(r0['datapath_id'], 16)
                # Only manage bridges selected by the bridgenames configuration
                if self.bridgenames is None or name in self.bridgenames:
                    self.managed_bridges[connection].append((vhost, dpid, name, bridge_uuid))
                    self.managed_conns[(vhost, dpid)] = connection
                    for m in self.apiroutine.waitForSend(OVSDBBridgeSetup(OVSDBBridgeSetup.UP,
                                                               dpid,
                                                               connection.ovsdb_systemid,
                                                               name,
                                                               connection,
                                                               connection.connmark,
                                                               vhost,
                                                               bridge_uuid)):
                        yield m
        except JsonRPCProtocolException:
            # Connection failures are reported through connection state events;
            # nothing to do here
            pass
    def _get_bridges(self, connection, protocol):
        """
        Retrieve the system-id and all bridges from a newly connected OVSDB
        endpoint, set up a monitor for bridge changes, and keep processing
        change notifications until the connection goes down.
        """
        try:
            try:
                vhost = protocol.vhost
                # Pre-initialize so the cleanup event in the except clause
                # below cannot raise NameError when the very first query fails
                system_id = None
                if not hasattr(connection, 'ovsdb_systemid'):
                    method, params = ovsdb.transact('Open_vSwitch', ovsdb.select('Open_vSwitch', [], ['external_ids']))
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    result = self.apiroutine.jsonrpc_result[0]
                    system_id = ovsdb.omap_getvalue(result['rows'][0]['external_ids'], 'system-id')
                    connection.ovsdb_systemid = system_id
                else:
                    system_id = connection.ovsdb_systemid
                if (vhost, system_id) in self.managed_systemids:
                    # A connection with the same system-id already exists;
                    # replace it with the new connection
                    oc = self.managed_systemids[(vhost, system_id)]
                    ep = _get_endpoint(oc)
                    econns = self.endpoint_conns.get((vhost, ep))
                    if econns:
                        try:
                            econns.remove(oc)
                        except ValueError:
                            pass
                    del self.managed_systemids[(vhost, system_id)]
                self.managed_systemids[(vhost, system_id)] = connection
                self.managed_bridges[connection] = []
                ep = _get_endpoint(connection)
                self.endpoint_conns.setdefault((vhost, ep), []).append(connection)
                method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_manager_bridges_monitor', {'Bridge':ovsdb.monitor_request(['name', 'datapath_id'])})
                for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                    yield m
                if 'error' in self.apiroutine.jsonrpc_result:
                    # The monitor is already set, cancel it first
                    method, params = ovsdb.monitor_cancel('ovsdb_manager_bridges_monitor')
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine, False):
                        yield m
                    method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_manager_bridges_monitor', {'Bridge':ovsdb.monitor_request(['name', 'datapath_id'])})
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    if 'error' in self.apiroutine.jsonrpc_result:
                        raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result))
            except Exception:
                # Always send the setup event so waiters (e.g. _manage_existing)
                # are released even when the setup failed
                for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                    yield m
                raise
            else:
                # Process initial bridges
                init_subprocesses = [self._update_bridge(connection, protocol, buuid, vhost)
                                    for buuid in self.apiroutine.jsonrpc_result['Bridge'].keys()]
                def init_process():
                    try:
                        with closing(self.apiroutine.executeAll(init_subprocesses, retnames = ())) as g:
                            for m in g:
                                yield m
                    except Exception:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                        raise
                    else:
                        for m in self.apiroutine.waitForSend(OVSDBConnectionSetup(system_id, connection, connection.connmark, vhost)):
                            yield m
                self.apiroutine.subroutine(init_process())
            # Wait for notify
            notification = JsonRPCNotificationEvent.createMatcher('update', connection, connection.connmark, _ismatch = lambda x: x.params[0] == 'ovsdb_manager_bridges_monitor')
            conn_down = protocol.statematcher(connection)
            while True:
                yield (conn_down, notification)
                if self.apiroutine.matcher is conn_down:
                    break
                else:
                    for buuid, v in self.apiroutine.event.params[1]['Bridge'].items():
                        # If a bridge's name or datapath-id is changed, we remove this bridge and add it again
                        if 'old' in v:
                            # A bridge is deleted
                            bridges = self.managed_bridges[connection]
                            for i in range(0, len(bridges)):
                                if buuid == bridges[i][3]:
                                    self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                               bridges[i][1],
                                                                               system_id,
                                                                               bridges[i][2],
                                                                               connection,
                                                                               connection.connmark,
                                                                               vhost,
                                                                               bridges[i][3],
                                                                               new_datapath_id =
                                                                                int(v['new']['datapath_id'], 16) if 'new' in v and 'datapath_id' in v['new']
                                                                                else None))
                                    del self.managed_conns[(vhost, bridges[i][1])]
                                    del bridges[i]
                                    break
                        if 'new' in v:
                            # A bridge is added
                            self.apiroutine.subroutine(self._update_bridge(connection, protocol, buuid, vhost))
        except JsonRPCProtocolException:
            pass
        finally:
            del connection._ovsdb_manager_get_bridges
    def _manage_existing(self):
        """
        Pick up connections that were established before this module started,
        then mark the module synchronized.
        """
        for m in callAPI(self.apiroutine, "jsonrpcserver", "getconnections", {}):
            yield m
        vb = self.vhostbind
        conns = self.apiroutine.retvalue
        for c in conns:
            if vb is None or c.protocol.vhost in vb:
                if not hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges = self.apiroutine.subroutine(self._get_bridges(c, c.protocol))
        matchers = [OVSDBConnectionSetup.createMatcher(None, c, c.connmark) for c in conns]
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m
    def _wait_for_sync(self):
        """Block the calling routine until initial synchronization finishes."""
        if not self._synchronized:
            yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)
    def _manage_conns(self):
        """
        Main routine: watch connection up/down events, start a bridge-monitor
        routine for each new connection and clean up when connections drop.
        """
        try:
            self.apiroutine.subroutine(self._manage_existing())
            vb = self.vhostbind
            if vb is not None:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                     _ismatch = lambda x: x.createby.vhost in vb)
            else:
                conn_up = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_UP)
                conn_down = JsonRPCConnectionStateEvent.createMatcher(state = JsonRPCConnectionStateEvent.CONNECTION_DOWN)
            while True:
                yield (conn_up, conn_down)
                if self.apiroutine.matcher is conn_up:
                    if not hasattr(self.apiroutine.event.connection, '_ovsdb_manager_get_bridges'):
                        self.apiroutine.event.connection._ovsdb_manager_get_bridges = self.apiroutine.subroutine(self._get_bridges(self.apiroutine.event.connection, self.apiroutine.event.createby))
                else:
                    e = self.apiroutine.event
                    conn = e.connection
                    bridges = self.managed_bridges.get(conn)
                    if bridges is not None:
                        del self.managed_systemids[(e.createby.vhost, conn.ovsdb_systemid)]
                        del self.managed_bridges[conn]
                        for vhost, dpid, name, buuid in bridges:
                            del self.managed_conns[(vhost, dpid)]
                            self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                       dpid,
                                                                       conn.ovsdb_systemid,
                                                                       name,
                                                                       conn,
                                                                       conn.connmark,
                                                                       e.createby.vhost,
                                                                       buuid))
                        # endpoint_conns is keyed by (vhost, endpoint) tuples;
                        # looking up by endpoint alone never matched and leaked
                        # dead connections in the endpoint index
                        econns = self.endpoint_conns.get((e.createby.vhost, _get_endpoint(conn)))
                        if econns is not None:
                            try:
                                econns.remove(conn)
                            except ValueError:
                                pass
        finally:
            # Module is shutting down: stop monitor routines and report every
            # managed bridge as DOWN
            for c in self.managed_bridges.keys():
                if hasattr(c, '_ovsdb_manager_get_bridges'):
                    c._ovsdb_manager_get_bridges.close()
                bridges = self.managed_bridges.get(c)
                if bridges is not None:
                    for vhost, dpid, name, buuid in bridges:
                        del self.managed_conns[(vhost, dpid)]
                        self.scheduler.emergesend(OVSDBBridgeSetup(OVSDBBridgeSetup.DOWN,
                                                                   dpid,
                                                                   c.ovsdb_systemid,
                                                                   name,
                                                                   c,
                                                                   c.connmark,
                                                                   c.protocol.vhost,
                                                                   buuid))
    def getconnection(self, datapathid, vhost = ''):
        "Get current connection of datapath"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_conns.get((vhost, datapathid))
    def waitconnection(self, datapathid, timeout = 30, vhost = ''):
        "Wait for a datapath connection"
        for m in self.getconnection(datapathid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                            OVSDBBridgeSetup.createMatcher(
                                    state = OVSDBBridgeSetup.UP,
                                    datapathid = datapathid, vhost = vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c
    def getdatapathids(self, vhost = ''):
        "Get All datapath IDs"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_conns.keys() if k[0] == vhost]
    def getalldatapathids(self):
        "Get all datapath IDs from any vhost. Return (vhost, datapathid) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_conns.keys())
    def getallconnections(self, vhost = ''):
        "Get all connections from vhost. If vhost is None, return all connections from any host"
        for m in self._wait_for_sync():
            yield m
        if vhost is None:
            self.apiroutine.retvalue = list(self.managed_bridges.keys())
        else:
            self.apiroutine.retvalue = list(k for k in self.managed_bridges.keys() if k.protocol.vhost == vhost)
    def getbridges(self, connection):
        "Get all (dpid, name, _uuid) tuple on this connection"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            self.apiroutine.retvalue = [(dpid, name, buuid) for _, dpid, name, buuid in bridges]
        else:
            self.apiroutine.retvalue = None
    def getallbridges(self, vhost = None):
        "Get all (dpid, name, _uuid) tuple for all connections, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        if vhost is not None:
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for c, bridges in self.managed_bridges.items()
                                        if c.protocol.vhost == vhost
                                        for _, dpid, name, buuid in bridges]
        else:
            self.apiroutine.retvalue = [(dpid, name, buuid)
                                        for c, bridges in self.managed_bridges.items()
                                        for _, dpid, name, buuid in bridges]
    def getbridge(self, connection, name):
        "Get datapath ID on this connection with specified name"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, n, _ in bridges:
                if n == name:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None
    def waitbridge(self, connection, name, timeout = 30):
        "Wait for bridge with specified name appears and return the datapath-id"
        bnames = self.bridgenames
        if bnames is not None and name not in bnames:
            raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear: it is not in the selected bridge names')
        for m in self.getbridge(connection, name):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.UP,
                                                         None,
                                                         None,
                                                         name,
                                                         connection
                                                         )
            conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                  connection,
                                                                  connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge ' + repr(name) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException('Connection is down before bridge ' + repr(name) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid
    def getbridgebyuuid(self, connection, uuid):
        "Get datapath ID of bridge on this connection with specified _uuid"
        for m in self._wait_for_sync():
            yield m
        bridges = self.managed_bridges.get(connection)
        if bridges is not None:
            for _, dpid, _, buuid in bridges:
                if buuid == uuid:
                    self.apiroutine.retvalue = dpid
                    return
            self.apiroutine.retvalue = None
        else:
            self.apiroutine.retvalue = None
    def waitbridgebyuuid(self, connection, uuid, timeout = 30):
        "Wait for bridge with specified _uuid appears and return the datapath-id"
        for m in self.getbridgebyuuid(connection, uuid):
            yield m
        if self.apiroutine.retvalue is None:
            bridge_setup = OVSDBBridgeSetup.createMatcher(state = OVSDBBridgeSetup.UP,
                                                         connection = connection,
                                                         bridgeuuid = uuid
                                                         )
            conn_down = JsonRPCConnectionStateEvent.createMatcher(JsonRPCConnectionStateEvent.CONNECTION_DOWN,
                                                                  connection,
                                                                  connection.connmark)
            for m in self.apiroutine.waitWithTimeout(timeout, bridge_setup, conn_down):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge ' + repr(uuid) + ' does not appear')
            elif self.apiroutine.matcher is conn_down:
                raise ConnectionResetException('Connection is down before bridge ' + repr(uuid) + ' appears')
            else:
                self.apiroutine.retvalue = self.apiroutine.event.datapathid
    def getsystemids(self, vhost = ''):
        "Get All system-ids"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.managed_systemids.keys() if k[0] == vhost]
    def getallsystemids(self):
        "Get all system-ids from any vhost. Return (vhost, system-id) pair."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.managed_systemids.keys())
    def getconnectionbysystemid(self, systemid, vhost = ''):
        "Get current connection with the specified system-id"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.managed_systemids.get((vhost, systemid))
    def waitconnectionbysystemid(self, systemid, timeout = 30, vhost = ''):
        "Wait for a connection with specified system-id"
        for m in self.getconnectionbysystemid(systemid, vhost):
            yield m
        c = self.apiroutine.retvalue
        if c is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                            OVSDBConnectionSetup.createMatcher(
                                    systemid, None, None, vhost)):
                yield m
            if self.apiroutine.timeout:
                raise ConnectionResetException('Datapath is not connected')
            self.apiroutine.retvalue = self.apiroutine.event.connection
        else:
            self.apiroutine.retvalue = c
    def getconnectionsbyendpoint(self, endpoint, vhost = ''):
        "Get connection by endpoint address (IP, IPv6 or UNIX socket address)"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self.endpoint_conns.get((vhost, endpoint))
    def getconnectionsbyendpointname(self, name, vhost = '', timeout = 30):
        "Get connection by endpoint name (Domain name, IP or IPv6 address)"
        # Resolve the name
        if not name:
            endpoint = ''
            # Fixed typo: the method is named getconnectionsbyendpoint;
            # the old name raised AttributeError at runtime
            for m in self.getconnectionsbyendpoint(endpoint, vhost):
                yield m
        else:
            request = (name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_ADDRCONFIG | socket.AI_V4MAPPED)
            # Resolve hostname
            for m in self.apiroutine.waitForSend(ResolveRequestEvent(request)):
                yield m
            for m in self.apiroutine.waitWithTimeout(timeout, ResolveResponseEvent.createMatcher(request)):
                yield m
            if self.apiroutine.timeout:
                # Resolve is only allowed through asynchronous resolver
                raise IOError('Resolve hostname timeout: ' + name)
            else:
                if hasattr(self.apiroutine.event, 'error'):
                    raise IOError('Cannot resolve hostname: ' + name)
                resp = self.apiroutine.event.response
                for r in resp:
                    raddr = r[4]
                    if isinstance(raddr, tuple):
                        # Ignore port
                        endpoint = raddr[0]
                    else:
                        # Unix socket? This should not happen, but in case...
                        endpoint = raddr
                    for m in self.getconnectionsbyendpoint(endpoint, vhost):
                        yield m
                    if self.apiroutine.retvalue is not None:
                        break
    def getendpoints(self, vhost = ''):
        "Get all endpoints for vhost"
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = [k[1] for k in self.endpoint_conns if k[0] == vhost]
    def getallendpoints(self):
        "Get all endpoints from any vhost. Return (vhost, endpoint) pairs."
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = list(self.endpoint_conns.keys())
    def getbridgeinfo(self, datapathid, vhost = ''):
        "Get (bridgename, systemid, bridge_uuid) tuple from bridge datapathid"
        for m in self.getconnection(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is not None:
            c = self.apiroutine.retvalue
            bridges = self.managed_bridges.get(c)
            if bridges is not None:
                for _, dpid, n, buuid in bridges:
                    if dpid == datapathid:
                        self.apiroutine.retvalue = (n, c.ovsdb_systemid, buuid)
                        return
                self.apiroutine.retvalue = None
            else:
                self.apiroutine.retvalue = None
    def waitbridgeinfo(self, datapathid, timeout = 30, vhost = ''):
        "Wait for bridge with datapathid, and return (bridgename, systemid, bridge_uuid) tuple"
        for m in self.getbridgeinfo(datapathid, vhost):
            yield m
        if self.apiroutine.retvalue is None:
            for m in self.apiroutine.waitWithTimeout(timeout,
                        OVSDBBridgeSetup.createMatcher(
                                    OVSDBBridgeSetup.UP, datapathid,
                                    None, None, None, None,
                                    vhost)):
                yield m
            if self.apiroutine.timeout:
                raise OVSDBBridgeNotAppearException('Bridge 0x%016x does not appear before timeout' % (datapathid,))
            e = self.apiroutine.event
            self.apiroutine.retvalue = (e.name, e.systemid, e.bridgeuuid)
Example #7
0
class OVSDBPortManager(Module):
    '''
    Manage Ports from OVSDB Protocol
    
    Tracks interfaces/ports of OVSDB-managed bridges and exposes query and
    wait APIs keyed by OpenFlow port number, interface name, or iface-id.
    '''
    # Flag consumed by the Module framework (base class declared elsewhere)
    service = True
    def __init__(self, server):
        '''
        Create the API routine container, bookkeeping structures and
        register the public module APIs.
        '''
        Module.__init__(self, server)
        self.apiroutine = RoutineContainer(self.scheduler)
        self.apiroutine.main = self._manage_ports
        self.routines.append(self.apiroutine)
        # (vhost, datapath_id) -> list of (port_uuid, interface_record)
        self.managed_ports = {}
        # (vhost, iface-id) -> list of (datapath_id, interface_record)
        self.managed_ids = {}
        # Routines currently monitoring an OVSDB connection
        self.monitor_routines = set()
        # port uuid -> bridge uuid
        self.ports_uuids = {}
        # Reference-counted waiter registries used by the wait* APIs
        self.wait_portnos = {}
        self.wait_names = {}
        self.wait_ids = {}
        # bridge uuid -> datapath id
        self.bridge_datapathid = {}
        self.createAPI(*(api(handler, self.apiroutine)
                         for handler in (self.getports,
                                         self.getallports,
                                         self.getportbyid,
                                         self.waitportbyid,
                                         self.getportbyname,
                                         self.waitportbyname,
                                         self.getportbyno,
                                         self.waitportbyno,
                                         self.resync)))
        self._synchronized = False
    def _get_interface_info(self, connection, protocol, buuid, interface_uuid, port_uuid):
        '''
        Query OVSDB for the interface row *interface_uuid* (belonging to port
        *port_uuid* on bridge *buuid*) and register it into managed_ports /
        managed_ids. Sets apiroutine.retvalue to a single-element list with
        the interface record, or [] when the interface disappeared, is in an
        error state, or the bridge is no longer tracked.
        '''
        try:
            # Wait (up to 5s each) until the ofport and ifindex columns are
            # assigned, then select the interface row itself
            method, params = ovsdb.transact('Open_vSwitch',
                                            ovsdb.wait('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]],
                                                       ["ofport"], [{"ofport":ovsdb.oset()}], False, 5000),
                                            ovsdb.wait('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]],
                                                       ["ifindex"], [{"ifindex":ovsdb.oset()}], False, 5000),
                                            ovsdb.select('Interface', [["_uuid", "==", ovsdb.uuid(interface_uuid)]],
                                                                         ["_uuid", "name", "ifindex", "ofport", "type", "external_ids"]))
            for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                yield m
            # Check each sub-operation result for an error
            r = self.apiroutine.jsonrpc_result[0]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[1]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error']))
            r = self.apiroutine.jsonrpc_result[2]
            if 'error' in r:
                raise JsonRPCErrorResultException('Error while acquiring interface: ' + repr(r['error']))
            if not r['rows']:
                # The interface disappeared before we could query it
                self.apiroutine.retvalue = []
                return
            r0 = r['rows'][0]
            if r0['ofport'] < 0:
                # Ignore this port because it is in an error state
                self.apiroutine.retvalue = []
                return
            # Normalize OVSDB typed values into plain Python values
            r0['_uuid'] = r0['_uuid'][1]
            r0['ifindex'] = ovsdb.getoptional(r0['ifindex'])
            r0['external_ids'] = ovsdb.getdict(r0['external_ids'])
            if buuid not in self.bridge_datapathid:
                # The bridge has been removed meanwhile
                self.apiroutine.retvalue = []
                return
            else:
                datapath_id = self.bridge_datapathid[buuid]
            if 'iface-id' in r0['external_ids']:
                # Also index the interface by its logical iface-id
                eid = r0['external_ids']['iface-id']
                r0['id'] = eid
                id_ports = self.managed_ids.setdefault((protocol.vhost, eid), [])
                id_ports.append((datapath_id, r0))
            else:
                r0['id'] = None
            self.managed_ports.setdefault((protocol.vhost, datapath_id),[]).append((port_uuid, r0))
            # Wake up pending waitportbyno / waitportbyname / waitportbyid calls
            notify = False
            if (protocol.vhost, datapath_id, r0['ofport']) in self.wait_portnos:
                notify = True
                del self.wait_portnos[(protocol.vhost, datapath_id, r0['ofport'])]
            if (protocol.vhost, datapath_id, r0['name']) in self.wait_names:
                notify = True
                del self.wait_names[(protocol.vhost, datapath_id, r0['name'])]
            if (protocol.vhost, r0['id']) in self.wait_ids:
                notify = True
                del self.wait_ids[(protocol.vhost, r0['id'])]
            if notify:
                for m in self.apiroutine.waitForSend(OVSDBPortUpNotification(connection, r0['name'],
                                                                             r0['ofport'], r0['id'],
                                                                             protocol.vhost, datapath_id,
                                                                             port = r0)):
                    yield m
            self.apiroutine.retvalue = [r0]
        except JsonRPCProtocolException:
            # Connection-level failure: report no interfaces
            self.apiroutine.retvalue = []
    def _remove_interface_id(self, connection, protocol, datapath_id, port):
        '''
        Drop *port* (matched by interface _uuid) from the managed_ids index.
        Callers guarantee port['id'] is set and registered.
        '''
        records = self.managed_ids.get((protocol.vhost, port['id']))
        for index, (_, record) in enumerate(records):
            if record['_uuid'] == port['_uuid']:
                del records[index]
                break
    def _remove_interface(self, connection, protocol, datapath_id, interface_uuid, port_uuid):
        '''
        Remove the interface *interface_uuid* from managed_ports and return
        the removed interface record, or None when it was not tracked.
        '''
        key = (protocol.vhost, datapath_id)
        ports = self.managed_ports.get(key)
        removed = None
        if ports is not None:
            for index, (_, record) in enumerate(ports):
                if record['_uuid'] == interface_uuid:
                    removed = record
                    if removed['id']:
                        # Keep the iface-id index consistent
                        self._remove_interface_id(connection, protocol, datapath_id, removed)
                    del ports[index]
                    break
            if not ports:
                # Last interface gone: forget the whole datapath entry
                del self.managed_ports[key]
        return removed
    def _remove_all_interface(self, connection, protocol, datapath_id, port_uuid, buuid):
        '''
        Remove every interface that belongs to *port_uuid* from
        managed_ports; return the list of removed interface records.
        '''
        key = (protocol.vhost, datapath_id)
        ports = self.managed_ports.get(key)
        if ports is not None:
            removed = [record for puuid, record in ports if puuid == port_uuid]
            kept = [(puuid, record) for puuid, record in ports if puuid != port_uuid]
            # Mutate the list in place so other references observe the change
            ports[:] = kept
            for record in removed:
                if record['id']:
                    self._remove_interface_id(connection, protocol, datapath_id, record)
            if not ports:
                del self.managed_ports[key]
            return removed
        # Datapath not tracked: just clear the port -> bridge mapping
        if port_uuid in self.ports_uuids and self.ports_uuids[port_uuid] == buuid:
            del self.ports_uuids[port_uuid]
        return []
    def _update_interfaces(self, connection, protocol, updateinfo, update = True):
        """
        There are several kinds of updates, they may appear together:
        
        1. New bridge created (or from initial updateinfo). We should add all the interfaces to the list.
        
        2. Bridge removed. Remove all the ports.
        
        3. name and datapath_id may be changed. We will consider this as a new bridge created, and an old
           bridge removed.
        
        4. Bridge ports modification, i.e. add/remove ports.
           a) Normally a port record is created/deleted together. A port record cannot exist without a
              bridge containing it.
                
           b) It is also possible that a port is removed from one bridge and added to another bridge, in
              this case the ports do not appear in the updateinfo
            
        5. Port interfaces modification, i.e. add/remove interfaces. The bridge record may not appear in this
           situation.
           
        We must consider these situations carefully and process them in correct order.
        """
        port_update = updateinfo.get('Port', {})
        bridge_update = updateinfo.get('Bridge', {})
        working_routines = []
        def process_bridge(buuid, uo):
            # Handle the update of a single bridge row: compute removed ports,
            # query new interfaces, and publish an 'update' notification
            try:
                nv = uo['new']
                if 'datapath_id' in nv:
                    datapath_id = int(nv['datapath_id'], 16)
                    self.bridge_datapathid[buuid] = datapath_id
                elif buuid in self.bridge_datapathid:
                    datapath_id = self.bridge_datapathid[buuid]
                else:
                    # This should not happen, but just in case...
                    for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitbridge', {'connection': connection,
                                                                    'name': nv['name'],
                                                                    'timeout': 5}):
                        yield m
                    datapath_id = self.apiroutine.retvalue
                if 'ports' in nv:
                    nset = set((p for _,p in ovsdb.getlist(nv['ports'])))
                else:
                    nset = set()
                if 'old' in uo:
                    ov = uo['old']
                    if 'ports' in ov:
                        oset = set((p for _,p in ovsdb.getlist(ov['ports'])))
                    else:
                        # new ports are not really added; it is only sent because datapath_id is modified
                        nset = set()
                        oset = set()
                    if 'datapath_id' in ov:
                        old_datapathid = int(ov['datapath_id'], 16)
                    else:
                        old_datapathid = datapath_id
                else:
                    oset = set()
                    old_datapathid = datapath_id
                # For every deleted port, remove the interfaces with this port _uuid
                remove = []
                add_routine = []
                for puuid in oset - nset:
                    remove += self._remove_all_interface(connection, protocol, old_datapathid, puuid, buuid)
                # For every port not changed, check if the interfaces are modified;
                for puuid in oset.intersection(nset):
                    if puuid in port_update:
                        # The port is modified, there should be an 'old' set and 'new' set
                        pu = port_update[puuid]
                        if 'old' in pu:
                            poset = set((p for _,p in ovsdb.getlist(pu['old']['interfaces'])))
                        else:
                            poset = set()
                        if 'new' in pu:
                            pnset = set((p for _,p in ovsdb.getlist(pu['new']['interfaces'])))
                        else:
                            pnset = set()
                        # Remove old interfaces
                        remove += [r for r in 
                                   (self._remove_interface(connection, protocol, datapath_id, iuuid, puuid)
                                    for iuuid in (poset - pnset)) if r is not None]
                        # Prepare to add new interfaces
                        add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                        for iuuid in (pnset - poset)]
                # For every port added, add the interfaces
                def add_port_interfaces(puuid):
                    # If the uuid does not appear in update info, we have no choice but to query interfaces with select
                    # we cannot use data from other bridges; the port may be moved from a bridge which is not tracked
                    try:
                        method, params = ovsdb.transact('Open_vSwitch', ovsdb.select('Port',
                                                                                     [["_uuid", "==", ovsdb.uuid(puuid)]],
                                                                                     ["interfaces"]))
                        for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                            yield m
                        r = self.apiroutine.jsonrpc_result[0]
                        if 'error' in r:
                            raise JsonRPCErrorResultException('Error when query interfaces from port ' + repr(puuid) + ': ' + r['error'])
                        if r['rows']:
                            interfaces = ovsdb.getlist(r['rows'][0]['interfaces'])
                            with closing(self.apiroutine.executeAll([self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                                                 for _,iuuid in interfaces])) as g:
                                for m in g:
                                    yield m
                            # NOTE(review): each r[0] is itself a list, and chain()
                            # over a single generator does not flatten — this looks
                            # like chain.from_iterable was intended; confirm
                            self.apiroutine.retvalue = list(itertools.chain(r[0] for r in self.apiroutine.retvalue))
                        else:
                            self.apiroutine.retvalue = []
                    except JsonRPCProtocolException:
                        self.apiroutine.retvalue = []
                    except ConnectionResetException:
                        self.apiroutine.retvalue = []
                
                for puuid in nset - oset:
                    self.ports_uuids[puuid] = buuid
                    if puuid in port_update and 'new' in port_update[puuid] \
                            and 'old' not in port_update[puuid]:
                        # Add all the interfaces in 'new'
                        interfaces = ovsdb.getlist(port_update[puuid]['new']['interfaces'])
                        add_routine += [self._get_interface_info(connection, protocol, buuid, iuuid, puuid)
                                        for _,iuuid in interfaces]
                    else:
                        add_routine.append(add_port_interfaces(puuid))
                # Execute the add_routine
                try:
                    with closing(self.apiroutine.executeAll(add_routine)) as g:
                        for m in g:
                            yield m
                except:
                    add = []
                    raise
                else:
                    # NOTE(review): same non-flattening chain() pattern as above — confirm
                    add = list(itertools.chain(r[0] for r in self.apiroutine.retvalue))
                finally:
                    # Publish the notification even when an exception is propagating
                    if update:
                        self.scheduler.emergesend(
                                ModuleNotification(self.getServiceName(), 'update',
                                                   datapathid = datapath_id,
                                                   connection = connection,
                                                   vhost = protocol.vhost,
                                                   add = add, remove = remove,
                                                   reason = 'bridgemodify'
                                                            if 'old' in uo
                                                            else 'bridgeup'
                                                   ))
            except JsonRPCProtocolException:
                pass
            except ConnectionResetException:
                pass
            except OVSDBBridgeNotAppearException:
                pass
        ignore_ports = set()
        for buuid, uo in bridge_update.items():
            # Bridge removals are ignored because we process OVSDBBridgeSetup event instead
            if 'old' in uo:
                if 'ports' in uo['old']:
                    oset = set((puuid for _, puuid in ovsdb.getlist(uo['old']['ports'])))
                    ignore_ports.update(oset)
                if 'new' not in uo:
                    if buuid in self.bridge_datapathid:
                        del self.bridge_datapathid[buuid]
            if 'new' in uo:
                # If bridge contains this port is updated, we process the port update totally in bridge,
                # so we ignore it later
                if 'ports' in uo['new']:
                    nset = set((puuid for _, puuid in ovsdb.getlist(uo['new']['ports'])))
                    ignore_ports.update(nset)
                working_routines.append(process_bridge(buuid, uo))
        def process_port(buuid, port_uuid, interfaces, remove_ids):
            # Handle an interface-set change for a port whose bridge row did not change
            if buuid not in self.bridge_datapathid:
                return
            datapath_id = self.bridge_datapathid[buuid]
            ports = self.managed_ports.get((protocol.vhost, datapath_id))
            remove = []
            if ports is not None:
                remove = [p for _,p in ports if p['_uuid'] in remove_ids]
                not_remove = [(_,p) for _,p in ports if p['_uuid'] not in remove_ids]
                ports[:len(not_remove)] = not_remove
                del ports[len(not_remove):]
            if interfaces:
                try:
                    with closing(self.apiroutine.executeAll([self._get_interface_info(connection, protocol, buuid, iuuid, port_uuid)
                                                         for iuuid in interfaces])) as g:
                        for m in g:
                            yield m
                    add = list(itertools.chain((r[0] for r in self.apiroutine.retvalue if r[0])))
                except Exception:
                    self._logger.warning("Cannot get new port information", exc_info = True)
                    add = []
            else:
                add = []
            if update:
                for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'update', datapathid = datapath_id,
                                                                                                          connection = connection,
                                                                                                          vhost = protocol.vhost,
                                                                                                          add = add, remove = remove,
                                                                                                          reason = 'bridgemodify'
                                                                                                          )):
                    yield m
        for puuid, po in port_update.items():
            if puuid not in ignore_ports:
                bridge_id = self.ports_uuids.get(puuid)
                if bridge_id is not None:
                    # NOTE(review): a plain [] lookup raises KeyError when the bridge
                    # is unknown, so the `is not None` check below can never trigger —
                    # confirm whether .get() was intended here
                    datapath_id = self.bridge_datapathid[bridge_id]
                    if datapath_id is not None:
                        # This port is modified
                        if 'new' in po:
                            nset = set((iuuid for _, iuuid in ovsdb.getlist(po['new']['interfaces'])))
                        else:
                            nset = set()
                        if 'old' in po:
                            oset = set((iuuid for _, iuuid in ovsdb.getlist(po['old']['interfaces'])))
                        else:
                            oset = set()
                        working_routines.append(process_port(bridge_id, puuid, nset - oset, oset - nset))
        if update:
            # Incremental update: run the processors in background subroutines
            for r in working_routines:
                self.apiroutine.subroutine(r)
        else:
            # Initial synchronization: run all processors, then always signal done
            try:
                with closing(self.apiroutine.executeAll(working_routines, None, ())) as g:
                    for m in g:
                        yield m
            finally:
                self.scheduler.emergesend(OVSDBConnectionPortsSynchronized(connection))
    def _get_ports(self, connection, protocol):
        '''
        Monitor routine for one OVSDB connection: install a monitor on the
        Bridge and Port tables, process the initial dump, then apply every
        incremental update until the connection goes down.
        '''
        try:
            try:
                method, params = ovsdb.monitor('Open_vSwitch', 'ovsdb_port_manager_interfaces_monitor', {
                                                    'Bridge':[ovsdb.monitor_request(["name", "datapath_id", "ports"])],
                                                    'Port':[ovsdb.monitor_request(["interfaces"])]
                                                })
                for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                    yield m
                if 'error' in self.apiroutine.jsonrpc_result:
                    # The monitor is already set, cancel it first
                    method2, params2 = ovsdb.monitor_cancel('ovsdb_port_manager_interfaces_monitor')
                    for m in protocol.querywithreply(method2, params2, connection, self.apiroutine, False):
                        yield m
                    for m in protocol.querywithreply(method, params, connection, self.apiroutine):
                        yield m
                    if 'error' in self.apiroutine.jsonrpc_result:
                        raise JsonRPCErrorResultException('OVSDB request failed: ' + repr(self.apiroutine.jsonrpc_result))
                r = self.apiroutine.jsonrpc_result
            except:
                # Signal synchronization even on failure so waiters are not stuck
                for m in self.apiroutine.waitForSend(OVSDBConnectionPortsSynchronized(connection)):
                    yield m
                raise
            # This is the initial state, it should contains all the ids of ports and interfaces
            self.apiroutine.subroutine(self._update_interfaces(connection, protocol, r, False))
            update_matcher = JsonRPCNotificationEvent.createMatcher('update', connection, connection.connmark,
                                                                    _ismatch = lambda x: x.params[0] == 'ovsdb_port_manager_interfaces_monitor')
            conn_state = protocol.statematcher(connection)
            while True:
                yield (update_matcher, conn_state)
                if self.apiroutine.matcher is conn_state:
                    # Connection went down: stop monitoring
                    break
                else:
                    # Incremental update from the monitor notification
                    self.apiroutine.subroutine(self._update_interfaces(connection, protocol, self.apiroutine.event.params[1], True))
        except JsonRPCProtocolException:
            pass
        finally:
            # Unregister this routine from the set of active monitors
            if self.apiroutine.currentroutine in self.monitor_routines:
                self.monitor_routines.remove(self.apiroutine.currentroutine)
    def _get_existing_ports(self):
        '''
        Start a monitor routine for every current OVSDB connection and wait
        until all of them finish their initial port synchronization.
        '''
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getallconnections', {'vhost':None}):
            yield m
        connections = self.apiroutine.retvalue
        self.monitor_routines.update(
            self.apiroutine.subroutine(self._get_ports(c, c.protocol))
            for c in connections)
        matchers = [OVSDBConnectionPortsSynchronized.createMatcher(c)
                    for c in connections]
        for m in self.apiroutine.waitForAll(*matchers):
            yield m
        self._synchronized = True
        # Tell other modules the initial state is ready
        for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(), 'synchronized')):
            yield m
    def _wait_for_sync(self):
        '''
        Block until the initial synchronization is done; no-op afterwards.
        '''
        if self._synchronized:
            return
        yield (ModuleNotification.createMatcher(self.getServiceName(), 'synchronized'),)
    def _manage_ports(self):
        '''
        Main routine: start monitoring existing connections, then react to
        new OVSDB connections and bridge DOWN events to keep the managed
        port structures consistent. On shutdown, closes all monitor
        routines and sends an 'unsynchronized' notification.
        '''
        try:
            self.apiroutine.subroutine(self._get_existing_ports())
            connsetup = OVSDBConnectionSetup.createMatcher()
            bridgedown = OVSDBBridgeSetup.createMatcher(OVSDBBridgeSetup.DOWN)
            while True:
                yield (connsetup, bridgedown)
                e = self.apiroutine.event
                if self.apiroutine.matcher is connsetup:
                    # A new OVSDB connection: start a port monitor for it
                    self.monitor_routines.add(self.apiroutine.subroutine(self._get_ports(e.connection, e.connection.protocol)))
                else:
                    # Remove ports of the bridge
                    ports = self.managed_ports.get((e.vhost, e.datapathid))
                    if ports is not None:
                        ports_original = ports
                        ports = [p for _, p in ports]
                        for p in ports:
                            if p['id']:
                                self._remove_interface_id(e.connection,
                                                          e.connection.protocol, e.datapathid, p)
                        newdpid = getattr(e, 'new_datapath_id', None)
                        buuid = e.bridgeuuid
                        if newdpid is not None:
                            # This bridge changes its datapath id
                            if buuid in self.bridge_datapathid and self.bridge_datapathid[buuid] == e.datapathid:
                                self.bridge_datapathid[buuid] = newdpid
                            def re_add_interfaces():
                                # Re-query every interface of this bridge so it is
                                # registered under the new datapath id
                                with closing(self.apiroutine.executeAll(
                                    [self._get_interface_info(e.connection, e.connection.protocol, buuid,
                                                              r['_uuid'], puuid)
                                     for puuid, r in ports_original])) as g:
                                    for m in g:
                                        yield m
                                add = list(itertools.chain(r[0] for r in self.apiroutine.retvalue))
                                for m in self.apiroutine.waitForSend(ModuleNotification(self.getServiceName(),
                                                  'update', datapathid = e.datapathid,
                                                  connection = e.connection,
                                                  vhost = e.vhost,
                                                  add = add, remove = [],
                                                  reason = 'bridgeup'
                                                  )):
                                    yield m
                            self.apiroutine.subroutine(re_add_interfaces())
                        else:
                            # The ports are removed: clear the port -> bridge mapping.
                            # BUG FIX: the original tested `puuid in self.ports_uuids[puuid]`,
                            # a substring test against the mapped value (raising KeyError
                            # for unknown uuids) instead of membership in the dict itself.
                            for puuid, _ in ports_original:
                                if puuid in self.ports_uuids and self.ports_uuids[puuid] == buuid:
                                    del self.ports_uuids[puuid]
                        del self.managed_ports[(e.vhost, e.datapathid)]
                        self.scheduler.emergesend(ModuleNotification(self.getServiceName(),
                                                  'update', datapathid = e.datapathid,
                                                  connection = e.connection,
                                                  vhost = e.vhost,
                                                  add = [], remove = ports,
                                                  reason = 'bridgedown'
                                                  ))
        finally:
            # Module is stopping: close every monitor routine and notify others
            for r in list(self.monitor_routines):
                r.close()
            self.scheduler.emergesend(ModuleNotification(self.getServiceName(), 'unsynchronized'))
    def getports(self, datapathid, vhost = ''):
        "Return all ports of a specified datapath"
        for m in self._wait_for_sync():
            yield m
        tracked = self.managed_ports.get((vhost, datapathid), [])
        self.apiroutine.retvalue = [port for _, port in tracked]
    def getallports(self, vhost = None):
        "Return all (datapathid, port, vhost) tuples, optionally filtered by vhost"
        for m in self._wait_for_sync():
            yield m
        # vhost None means no filtering
        self.apiroutine.retvalue = [(dpid, port, vh)
                                    for (vh, dpid), records in self.managed_ports.items()
                                    if vhost is None or vh == vhost
                                    for _, port in records]
    def getportbyno(self, datapathid, portno, vhost = ''):
        "Return port with specified portno"
        # Normalize to the low 16 bits, same as waitportbyno
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyno(datapathid, portno, vhost)
    def _getportbyno(self, datapathid, portno, vhost = ''):
        # Linear scan over the tracked interfaces of this datapath
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        return next((port for _, port in ports if port['ofport'] == portno), None)
    def waitportbyno(self, datapathid, portno, timeout = 30, vhost = ''):
        '''
        Wait for port with specified portno.
        
        Raises OVSDBPortNotAppearException if the port does not appear
        before *timeout* seconds.
        '''
        # Normalize to the low 16 bits, same as getportbyno
        portno &= 0xffff
        for m in self._wait_for_sync():
            yield m
        def waitinner():
            p = self._getportbyno(datapathid, portno, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Register as a waiter; counts are kept so concurrent
                    # waiters on the same key do not clobber each other
                    self.wait_portnos[(vhost, datapathid, portno)] = \
                            self.wait_portnos.get((vhost, datapathid, portno),0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(None, None, portno, None, vhost, datapathid),)
                except:
                    # Timed out or closed: decrease the waiter count
                    v = self.wait_portnos.get((vhost, datapathid, portno))
                    if v is not None:
                        if v <= 1:
                            del self.wait_portnos[(vhost, datapathid, portno)]
                        else:
                            self.wait_portnos[(vhost, datapathid, portno)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(portno) + ' does not appear before timeout')
    def getportbyname(self, datapathid, name, vhost = ''):
        "Return port with specified name"
        if isinstance(name, bytes):
            # Stored interface names are unicode strings
            name = name.decode('utf-8')
        for m in self._wait_for_sync():
            yield m
        self.apiroutine.retvalue = self._getportbyname(datapathid, name, vhost)
    def _getportbyname(self, datapathid, name, vhost = ''):
        # Linear scan over the tracked interfaces of this datapath
        ports = self.managed_ports.get((vhost, datapathid))
        if ports is None:
            return None
        return next((port for _, port in ports if port['name'] == name), None)
    def waitportbyname(self, datapathid, name, timeout = 30, vhost = ''):
        '''
        Wait for port with specified name.
        
        Raises OVSDBPortNotAppearException if the port does not appear
        before *timeout* seconds.
        '''
        if isinstance(name, bytes):
            # Consistent with getportbyname: stored names are unicode strings
            name = name.decode('utf-8')
        for m in self._wait_for_sync():
            yield m
        def waitinner():
            p = self._getportbyname(datapathid, name, vhost)
            if p is not None:
                self.apiroutine.retvalue = p
            else:
                try:
                    # Register as a reference-counted waiter.
                    # BUG FIX: the original read the current count from
                    # self.wait_portnos (keyed by port number) instead of
                    # self.wait_names, so concurrent waiters were miscounted.
                    self.wait_names[(vhost, datapathid, name)] = \
                            self.wait_names.get((vhost, datapathid, name), 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(None, name, None, None, vhost, datapathid),)
                except:
                    # Timed out or closed: decrease the waiter count
                    v = self.wait_names.get((vhost, datapathid, name))
                    if v is not None:
                        if v <= 1:
                            del self.wait_names[(vhost, datapathid, name)]
                        else:
                            self.wait_names[(vhost, datapathid, name)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = self.apiroutine.event.port
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(name) + ' does not appear before timeout')
    def getportbyid(self, id, vhost = ''):
        "Return port with the specified id. The return value is a pair: (datapath_id, port)"
        for m in self._wait_for_sync():
            yield m
        # BUG FIX: the original assigned the result to self.apiroutine,
        # clobbering the routine container instead of setting retvalue
        self.apiroutine.retvalue = self._getportbyid(id, vhost)
    def _getportbyid(self, id, vhost = ''):
        # managed_ids values are lists of (datapath_id, port) pairs;
        # return the first match, or None when the id is unknown
        matches = self.managed_ids.get((vhost, id))
        return matches[0] if matches else None
    def waitportbyid(self, id, timeout = 30, vhost = ''):
        '''
        Wait for port with the specified id. The return value is a pair
        (datapath_id, port).
        
        Raises OVSDBPortNotAppearException if the port does not appear
        before *timeout* seconds.
        '''
        for m in self._wait_for_sync():
            yield m
        def waitinner():
            p = self._getportbyid(id, vhost)
            if p is None:
                try:
                    # Register as a reference-counted waiter on this id
                    self.wait_ids[(vhost, id)] = self.wait_ids.get((vhost, id), 0) + 1
                    yield (OVSDBPortUpNotification.createMatcher(None, None, None, id, vhost),)
                except:
                    # Timed out or closed: decrease the waiter count
                    v = self.wait_ids.get((vhost, id))
                    if v is not None:
                        if v <= 1:
                            del self.wait_ids[(vhost, id)]
                        else:
                            self.wait_ids[(vhost, id)] = v - 1
                    raise
                else:
                    self.apiroutine.retvalue = (self.apiroutine.event.datapathid,
                                                self.apiroutine.event.port)
            else:
                self.apiroutine.retvalue = p
        for m in self.apiroutine.executeWithTimeout(timeout, waitinner()):
            yield m
        if self.apiroutine.timeout:
            raise OVSDBPortNotAppearException('Port ' + repr(id) + ' does not appear before timeout')
    def resync(self, datapathid, vhost = ''):
        '''
        Resync with current ports
        '''
        # Monitor messages may be dropped when the OVSDB connection is very
        # busy; recover by restarting the connection so the state is re-sent.
        if (vhost, datapathid) not in self.managed_ports:
            # Nothing is managed for this datapath; nothing to resync
            self.apiroutine.retvalue = None
            return
        for m in callAPI(self.apiroutine, 'ovsdbmanager', 'getconnection', {'datapathid': datapathid, 'vhost':vhost}):
            yield m
        conn = self.apiroutine.retvalue
        if conn is not None:
            # For now, we restart the connection and wait until it is back up
            for m in conn.reconnect(False):
                yield m
            for m in self.apiroutine.waitWithTimeout(0.1):
                yield m
            for m in callAPI(self.apiroutine, 'ovsdbmanager', 'waitconnection', {'datapathid': datapathid,
                                                                                 'vhost': vhost}):
                yield m
        self.apiroutine.retvalue = None