コード例 #1
0
ファイル: test_addrpool.py プロジェクト: 0x90/pyroute2
    def test_free_reverse_fail(self):
        '''
        free() of address 0 on a reverse pool whose minaddr is 1 must
        raise KeyError; the test fails if no exception is raised.
        '''
        ap = AddrPool(minaddr=1, maxaddr=1024, reverse=True)
        try:
            ap.free(0)
        except KeyError:
            pass
        else:
            # without this branch the test passed even when free()
            # silently accepted the address
            raise AssertionError('free(0) did not raise KeyError')
コード例 #2
0
ファイル: test_addrpool.py プロジェクト: welterde/pyroute2
    def test_free_reverse_fail(self):
        '''
        free() of address 0 on a reverse pool whose minaddr is 1 must
        raise KeyError; the test fails if no exception is raised.
        '''
        ap = AddrPool(minaddr=1, maxaddr=1024, reverse=True)
        try:
            ap.free(0)
        except KeyError:
            pass
        else:
            # without this branch the test passed even when free()
            # silently accepted the address
            raise AssertionError('free(0) did not raise KeyError')
コード例 #3
0
ファイル: test_addrpool.py プロジェクト: nazarewk/pyroute2
    def test_setaddr_allocated(self):
        '''
        setaddr(addr, 'allocated') marks a free address as taken and
        increments `allocated`; free() then undoes both.
        '''
        pool = AddrPool()
        neighbour = pool.alloc() + 1

        # before: the address right after the first allocation is free
        _base, _bit, taken = pool.locate(neighbour)
        assert not taken
        assert pool.allocated == 1

        # force it into the allocated state
        pool.setaddr(neighbour, 'allocated')
        _base, _bit, taken = pool.locate(neighbour)
        assert taken
        assert pool.allocated == 2

        # and release it again
        pool.free(neighbour)
        _base, _bit, taken = pool.locate(neighbour)
        assert not taken
        assert pool.allocated == 1
コード例 #4
0
    def test_setaddr_allocated(self):
        '''
        setaddr(addr, 'allocated') marks a free address as taken and
        increments `allocated`; free() then undoes both.
        '''
        pool = AddrPool()
        neighbour = pool.alloc() + 1

        # before: the address right after the first allocation is free
        _base, _bit, taken = pool.locate(neighbour)
        assert not taken
        assert pool.allocated == 1

        # force it into the allocated state
        pool.setaddr(neighbour, 'allocated')
        _base, _bit, taken = pool.locate(neighbour)
        assert taken
        assert pool.allocated == 2

        # and release it again
        pool.free(neighbour)
        _base, _bit, taken = pool.locate(neighbour)
        assert not taken
        assert pool.allocated == 1
コード例 #5
0
ファイル: test_addrpool.py プロジェクト: nazarewk/pyroute2
    def test_setaddr_free(self):
        '''
        setaddr(addr, 'free') leaves an already free address free and
        the counter unchanged; on the allocated address it releases it,
        after which free() of that address must raise KeyError.
        '''
        ap = AddrPool()
        f = ap.alloc()
        # f + 1 has not been allocated
        _, _, is_allocated = ap.locate(f + 1)
        assert not is_allocated
        assert ap.allocated == 1
        # marking a free address as free changes nothing
        ap.setaddr(f + 1, 'free')
        _, _, is_allocated = ap.locate(f + 1)
        assert not is_allocated
        assert ap.allocated == 1
        # release the allocated address through setaddr()
        ap.setaddr(f, 'free')
        _, _, is_allocated = ap.locate(f)
        assert not is_allocated
        assert ap.allocated == 0
        # the address is free now, so free() must fail
        try:
            ap.free(f)
        except KeyError:
            pass
        else:
            raise AssertionError('free() of a free address did not '
                                 'raise KeyError')
コード例 #6
0
    def test_setaddr_free(self):
        '''
        setaddr(addr, 'free') leaves an already free address free and
        the counter unchanged; on the allocated address it releases it,
        after which free() of that address must raise KeyError.
        '''
        ap = AddrPool()
        f = ap.alloc()
        # f + 1 has not been allocated
        _, _, is_allocated = ap.locate(f + 1)
        assert not is_allocated
        assert ap.allocated == 1
        # marking a free address as free changes nothing
        ap.setaddr(f + 1, 'free')
        _, _, is_allocated = ap.locate(f + 1)
        assert not is_allocated
        assert ap.allocated == 1
        # release the allocated address through setaddr()
        ap.setaddr(f, 'free')
        _, _, is_allocated = ap.locate(f)
        assert not is_allocated
        assert ap.allocated == 0
        # the address is free now, so free() must fail
        try:
            ap.free(f)
        except KeyError:
            pass
        else:
            raise AssertionError('free() of a free address did not '
                                 'raise KeyError')
コード例 #7
0
class NetlinkMixin(object):
    '''
    Generic netlink socket

    Shared machinery for netlink sockets: sequence number allocation
    (`addr_pool`), a per-sequence-number backlog of parsed messages,
    message callbacks, the encoding/decoding policy (`marshal`) and the
    locks that let several threads share one socket. The transport
    methods used here -- `_sendto()`, `_recv()`, `_recv_into()`,
    `getsockopt()` and `sendto_gate()` -- are not implemented in this
    class and must come from a subclass or another base class.
    '''
    def __init__(self,
                 family=NETLINK_GENERIC,
                 port=None,
                 pid=None,
                 fileno=None,
                 sndbuf=1048576,
                 rcvbuf=1048576,
                 all_ns=False,
                 async_qsize=None,
                 nlm_generator=None):
        '''
        Initialize the socket state. This method itself creates no
        netlink socket (`self._sock` stays None); it opens a control
        pipe (`_ctrl_read`, `_ctrl_write`).

            - family -- netlink family, stored in `self.family`
            - port, pid -- if `pid` is None, `self.pid` is the process
              id masked to 22 bits and `port` is kept (`self.fixed` is
              True only when a port is given); if `pid` is 0,
              `self.pid` is the unmasked process id; any other `pid`
              is used as is. In the last two cases `port` is ignored
              and `self.port` stays 0.
            - fileno -- an already open descriptor; raises
              NotImplementedError on Python 2
            - sndbuf, rcvbuf -- stored as `_sndbuf` / `_rcvbuf`
            - all_ns -- stored as `self.all_ns`
            - async_qsize -- maxsize of `buffer_queue`; None means
              `config.async_qsize`
            - nlm_generator -- None means `config.nlm_generator`; if
              the result is false, `nlm_request()` and `get()` return
              tuples instead of generators
        '''
        #
        # That's a trick. Python 2 is not able to construct
        # sockets from an open FD.
        #
        # So raise an exception, if the major version is < 3
        # and fileno is not None.
        #
        # Do NOT use fileno in a core pyroute2 functionality,
        # since the core should be both Python 2 and 3
        # compatible.
        #
        super(NetlinkMixin, self).__init__()
        if fileno is not None and sys.version_info[0] < 3:
            raise NotImplementedError('fileno parameter is not supported '
                                      'on Python < 3.2')

        # 8<-----------------------------------------
        # The constructor arguments as given; clone() uses them to
        # build a new instance of the same type.
        self.config = {
            'family': family,
            'port': port,
            'pid': pid,
            'fileno': fileno,
            'sndbuf': sndbuf,
            'rcvbuf': rcvbuf,
            'all_ns': all_ns,
            'async_qsize': async_qsize,
            'nlm_generator': nlm_generator
        }
        # 8<-----------------------------------------
        # nlm_request() allocates its sequence numbers from this pool
        self.addr_pool = AddrPool(minaddr=0x000000ff, maxaddr=0x0000ffff)
        self.epid = None
        self.port = 0
        self.fixed = True
        self.family = family
        self._fileno = fileno
        self._sndbuf = sndbuf
        self._rcvbuf = rcvbuf
        # msg_seq -> list of parsed messages; messages whose sequence
        # number has no entry here are collected in queue 0
        self.backlog = {0: []}
        self.callbacks = []  # [(predicate, callback, args), ...]
        self.pthread = None
        self.closed = False
        self.uname = config.uname
        self.capabilities = {
            'create_bridge': config.kernel > [3, 2, 0],
            'create_bond': config.kernel > [3, 2, 0],
            'create_dummy': True,
            'provide_master': config.kernel[0] > 2
        }
        self.backlog_lock = threading.Lock()
        # only the thread holding read_lock reads from the socket
        self.read_lock = threading.Lock()
        self.sys_lock = threading.RLock()
        # set by the reading thread once new data is in the backlog
        self.change_master = threading.Event()
        # per-sequence-number locks, used as self.lock[msg_seq]
        self.lock = LockFactory()
        self._sock = None
        self._ctrl_read, self._ctrl_write = os.pipe()
        if async_qsize is None:
            async_qsize = config.async_qsize
        self.async_qsize = async_qsize
        if nlm_generator is None:
            nlm_generator = config.nlm_generator
        self.nlm_generator = nlm_generator
        self.buffer_queue = Queue(maxsize=async_qsize)
        self.qsize = 0
        self.log = []
        self.get_timeout = 30
        self.get_timeout_exception = None
        self.all_ns = all_ns
        if pid is None:
            # NOTE(review): 22 bits are kept, presumably leaving the
            # upper bits for the port -- confirm against bind()
            self.pid = os.getpid() & 0x3fffff
            self.port = port
            self.fixed = self.port is not None
        elif pid == 0:
            self.pid = os.getpid()
        else:
            self.pid = pid
        # 8<-----------------------------------------
        self.groups = 0
        self.marshal = Marshal()
        # 8<-----------------------------------------
        if not nlm_generator:
            # Generators disabled: shadow nlm_request() and get() on
            # the instance with wrappers that return tuples; the
            # original generator methods stay reachable as
            # _genlm_request / _genlm_get.

            def nlm_request(*argv, **kwarg):
                return tuple(self._genlm_request(*argv, **kwarg))

            def get(*argv, **kwarg):
                return tuple(self._genlm_get(*argv, **kwarg))

            self._genlm_request = self.nlm_request
            self._genlm_get = self.get

            self.nlm_request = nlm_request
            self.get = get

        # Set defaults
        self.post_init()

    def post_init(self):
        '''
        Hook called at the end of __init__(); does nothing here.
        '''
        pass

    def clone(self):
        '''
        Return a new instance of the same type built from the
        constructor arguments saved in `self.config`.
        '''
        return type(self)(**self.config)

    def close(self, code=errno.ECONNRESET):
        '''
        Close the control pipe.

        If `code` > 0 and `self.pthread` is set, a synthetic 28-byte
        message carrying `code` is queued into `buffer_queue` first.
        NOTE(review): its type field is 2, presumably NLMSG_ERROR, so
        that a waiting reader gets the error -- confirm.
        '''
        if code > 0 and self.pthread:
            self.buffer_queue.put(
                struct.pack('IHHQIQQ', 28, 2, 0, 0, code, 0, 0))
        # Close each end separately, so a failure on one end (e.g. it
        # is closed already) does not leave the other one open.
        for fd in (self._ctrl_write, self._ctrl_read):
            try:
                os.close(fd)
            except OSError:
                # ignore the case when it is closed already
                pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def release(self):
        '''
        Deprecated alias for close().
        '''
        log.warning("The `release()` call is deprecated")
        log.warning("Use `close()` instead")
        self.close()

    def register_callback(self, callback, predicate=lambda x: True, args=None):
        '''
        Register a callback to run on a message arrival.

        Callback is the function that will be called with the
        message as the first argument. Predicate is the optional
        callable object, that returns True or False. Upon True,
        the callback will be called. Upon False it will not.
        Args is a list or tuple of arguments.

        Simplest example, assume ipr is the IPRoute() instance::

            # create a simplest callback that will print messages
            def cb(msg):
                print(msg)

            # register callback for any message:
            ipr.register_callback(cb)

        More complex example, with filtering::

            # Set object's attribute after the message key
            def cb(msg, obj):
                obj.some_attr = msg["some key"]

            # Register the callback only for the loopback device, index 1:
            ipr.register_callback(cb,
                                  lambda x: x.get('index', None) == 1,
                                  (self, ))

        Please note: you do **not** need to register the default 0 queue
        to invoke callbacks on broadcast messages. Callbacks are
        iterated **before** messages get enqueued.
        '''
        if args is None:
            args = []
        self.callbacks.append((predicate, callback, args))

    def unregister_callback(self, callback):
        '''
        Remove the first reference to the function from the callback
        register
        '''
        cb = tuple(self.callbacks)
        for cr in cb:
            if cr[1] == callback:
                self.callbacks.pop(cb.index(cr))
                return

    def register_policy(self, policy, msg_class=None):
        '''
        Register netlink encoding/decoding policy. Can
        be specified in two ways:
        `nlsocket.register_policy(MSG_ID, msg_class)`
        to register one particular rule, or
        `nlsocket.register_policy({MSG_ID1: msg_class})`
        to register several rules at once.
        E.g.::

            policy = {RTM_NEWLINK: ifinfmsg,
                      RTM_DELLINK: ifinfmsg,
                      RTM_NEWADDR: ifaddrmsg,
                      RTM_DELADDR: ifaddrmsg}
            nlsocket.register_policy(policy)

        One can call `register_policy()` as many times,
        as one want to -- it will just extend the current
        policy scheme, not replace it.

        Returns the updated `self.marshal.msg_map`.
        '''
        if isinstance(policy, int) and msg_class is not None:
            policy = {policy: msg_class}

        assert isinstance(policy, dict)
        for key in policy:
            self.marshal.msg_map[key] = policy[key]

        return self.marshal.msg_map

    def unregister_policy(self, policy):
        '''
        Unregister policy. Policy can be:

            - int -- then it will just remove one policy
            - list or tuple of ints -- remove all given
            - dict -- remove policies by keys from dict

        In the last case the routine will ignore dict values,
        it is implemented so just to make it compatible with
        `get_policy_map()` return value.

        Raises KeyError for a key that is not registered. Returns the
        updated `self.marshal.msg_map`.
        '''
        if isinstance(policy, int):
            policy = [policy]
        elif isinstance(policy, dict):
            policy = list(policy)

        assert isinstance(policy, (tuple, list, set))

        for key in policy:
            del self.marshal.msg_map[key]

        return self.marshal.msg_map

    def get_policy_map(self, policy=None):
        '''
        Return policy for a given message type or for all
        message types. Policy parameter can be either int,
        or a list of ints. Always return dictionary.
        '''
        if policy is None:
            return self.marshal.msg_map

        if isinstance(policy, int):
            policy = [policy]

        assert isinstance(policy, (list, tuple, set))

        ret = {}
        for key in policy:
            ret[key] = self.marshal.msg_map[key]

        return ret

    def sendto(self, *argv, **kwarg):
        return self._sendto(*argv, **kwarg)

    def recv(self, *argv, **kwarg):
        return self._recv(*argv, **kwarg)

    def recv_into(self, *argv, **kwarg):
        return self._recv_into(*argv, **kwarg)

    def recv_ft(self, *argv, **kwarg):
        # the receive call used by get(); here the same as recv()
        return self._recv(*argv, **kwarg)

    def async_recv(self):
        '''
        Reader loop: poll the netlink socket and the control pipe and
        push every received buffer into `buffer_queue`.

        Returns when any event arrives on the control pipe, or after
        queueing an exception raised while receiving.
        '''
        poll = select.poll()
        poll.register(self._sock, select.POLLIN | select.POLLPRI)
        poll.register(self._ctrl_read, select.POLLIN | select.POLLPRI)
        sockfd = self._sock.fileno()
        while True:
            events = poll.poll()
            for (fd, event) in events:
                if fd == sockfd:
                    try:
                        # the whole 64000-byte buffer is queued, however
                        # many bytes recv_into() actually wrote
                        data = bytearray(64000)
                        self._sock.recv_into(data, 64000)
                        # NOTE(review): if buffer_queue is a bounded
                        # queue.Queue and it is full, put_nowait()
                        # raises and this reader stops after queueing
                        # that exception
                        self.buffer_queue.put_nowait(data)
                    except Exception as e:
                        self.buffer_queue.put(e)
                        return
                else:
                    # activity on the control pipe: stop reading
                    return

    def put(self,
            msg,
            msg_type,
            msg_flags=NLM_F_REQUEST,
            addr=(0, 0),
            msg_seq=0,
            msg_pid=None):
        '''
        Construct a message from a dictionary and send it to
        the socket. Parameters:

            - msg -- the message in the dictionary format
            - msg_type -- the message type
            - msg_flags -- the message flags to use in the request
            - addr -- `sendto()` addr, default `(0, 0)`
            - msg_seq -- sequence number to use
            - msg_pid -- pid to use, if `None` -- use `self.epid`,
              or os.getpid() when `self.epid` is not set

        A backlog entry for `msg_seq` is created if missing, so that
        `get()` collects the responses for it.

        Example::

            s = IPRSocket()
            s.bind()
            s.put({'index': 1}, RTM_GETLINK)
            s.get()
            s.close()

        Please notice, that the return value of `s.get()` can be
        not the result of `s.put()`, but any broadcast message.
        To fix that, use `msg_seq` -- the response must contain the
        same `msg['header']['sequence_number']` value.
        '''
        if msg_seq != 0:
            self.lock[msg_seq].acquire()
        try:
            if msg_seq not in self.backlog:
                self.backlog[msg_seq] = []
            if not isinstance(msg, nlmsg):
                msg_class = self.marshal.msg_map[msg_type]
                msg = msg_class(msg)
            if msg_pid is None:
                msg_pid = self.epid or os.getpid()
            msg['header']['type'] = msg_type
            msg['header']['flags'] = msg_flags
            msg['header']['sequence_number'] = msg_seq
            msg['header']['pid'] = msg_pid
            self.sendto_gate(msg, addr)
        finally:
            if msg_seq != 0:
                self.lock[msg_seq].release()

    def sendto_gate(self, msg, addr):
        '''
        Send an encoded message; must be provided by a subclass.
        '''
        raise NotImplementedError()

    def get(self,
            bufsize=DEFAULT_RCVBUF,
            msg_seq=0,
            terminate=None,
            callback=None):
        '''
        Yield parsed messages. If `msg_seq` is given, yield
        only messages with that `msg['header']['sequence_number']`,
        saving all other messages into `self.backlog`.

        The routine is thread-safe: only one caller at a time reads
        from the socket (`read_lock`); the others wait on
        `change_master` and pick their messages from the backlog.

        The `bufsize` parameter can be:

            - -1: bufsize will be calculated from the first 4 bytes of
                the network data
            - 0: bufsize will be calculated from SO_RCVBUF sockopt
            - int > 0: just a bufsize

        For `msg_seq != 0` the generator stops on NLMSG_DONE, on a
        message without NLM_F_MULTI, or when `terminate(msg)` returns
        a true value; a message carrying `msg['header']['error']`
        raises that error. If nothing is received for
        `self.get_timeout` seconds, `self.get_timeout_exception` is
        raised when set, otherwise the generator just stops.

        NOTE(review): messages are yielded while `backlog_lock` is
        held, so a consumer that blocks between items blocks every
        other reader of this socket.
        '''
        ctime = time.time()

        with self.lock[msg_seq]:
            if bufsize == -1:
                # get bufsize from the network data
                bufsize = struct.unpack("I", self.recv(4, MSG_PEEK))[0]
            elif bufsize == 0:
                # get bufsize from SO_RCVBUF
                bufsize = self.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2

            tmsg = None
            enough = False
            backlog_acquired = False
            try:
                while not enough:
                    # 8<-----------------------------------------------------------
                    #
                    # This stage changes the backlog, so use mutex to
                    # prevent side changes
                    self.backlog_lock.acquire()
                    backlog_acquired = True
                    ##
                    # Stage 1. BEGIN
                    #
                    # 8<-----------------------------------------------------------
                    #
                    # Check backlog and return already collected
                    # messages.
                    #
                    if msg_seq == 0 and self.backlog[0]:
                        # Zero queue.
                        #
                        # Load the backlog, if there is valid
                        # content in it
                        for msg in self.backlog[0]:
                            yield msg
                        self.backlog[0] = []
                        # And just exit
                        break
                    elif msg_seq != 0 and len(self.backlog.get(msg_seq, [])):
                        # Any other msg_seq.
                        #
                        # Collect messages up to the terminator.
                        # Terminator conditions:
                        #  * NLMSG_ERROR != 0
                        #  * NLMSG_DONE
                        #  * terminate() function (if defined)
                        #  * not NLM_F_MULTI
                        #
                        # Please note, that if terminator not occured,
                        # more `recv()` rounds CAN be required.
                        for msg in tuple(self.backlog[msg_seq]):

                            # Drop the message from the backlog, if any
                            self.backlog[msg_seq].remove(msg)

                            # If there is an error, raise exception
                            if msg['header'].get('error', None) is not None:
                                self.backlog[0].extend(self.backlog[msg_seq])
                                del self.backlog[msg_seq]
                                # The loop is done
                                raise msg['header']['error']

                            # If it is the terminator message, say "enough"
                            # and requeue all the rest into Zero queue
                            if terminate is not None:
                                tmsg = terminate(msg)
                                if isinstance(tmsg, nlmsg):
                                    yield msg
                            if (msg['header']['type'] == NLMSG_DONE) or tmsg:
                                # The loop is done
                                enough = True

                            # If it is just a normal message, append it to
                            # the response
                            if not enough:
                                # finish the loop on single messages
                                if not msg['header']['flags'] & NLM_F_MULTI:
                                    enough = True
                                yield msg

                            # Enough is enough, requeue the rest and delete
                            # our backlog
                            if enough:
                                self.backlog[0].extend(self.backlog[msg_seq])
                                del self.backlog[msg_seq]
                                break

                        # Next iteration
                        self.backlog_lock.release()
                        backlog_acquired = False
                    else:
                        # Stage 1. END
                        #
                        # 8<-------------------------------------------------------
                        #
                        # Stage 2. BEGIN
                        #
                        # 8<-------------------------------------------------------
                        #
                        # Receive the data from the socket and put the messages
                        # into the backlog
                        #
                        self.backlog_lock.release()
                        backlog_acquired = False
                        ##
                        #
                        # Control the timeout. We should not be within the
                        # function more than TIMEOUT seconds. All the locks
                        # MUST be released here.
                        #
                        if (msg_seq != 0) and \
                                (time.time() - ctime > self.get_timeout):
                            # requeue already received for that msg_seq;
                            # the backlog is shared, so modify it under
                            # its lock, and tolerate a msg_seq that has
                            # no backlog entry (get() without put())
                            with self.backlog_lock:
                                self.backlog[0].extend(
                                    self.backlog.pop(msg_seq, []))
                            # throw an exception
                            if self.get_timeout_exception:
                                raise self.get_timeout_exception()
                            else:
                                return
                        #
                        if self.read_lock.acquire(False):
                            try:
                                self.change_master.clear()
                                # If the socket is free to read from, occupy
                                # it and wait for the data
                                #
                                # This is a time consuming process, so all the
                                # locks, except the read lock must be released
                                data = self.recv_ft(bufsize)
                                # Parse data
                                msgs = self.marshal.parse(
                                    data, msg_seq, callback)
                                # Reset ctime -- timeout should be measured
                                # for every turn separately
                                ctime = time.time()
                                #
                                # Throttle on bursts: if buffer_queue grew
                                # by more than 10 entries since the last
                                # turn, sleep proportionally (0.01..3 s)
                                current = self.buffer_queue.qsize()
                                delta = current - self.qsize
                                delay = 0
                                if delta > 10:
                                    delay = min(
                                        3, max(0.01,
                                               float(current) / 60000))
                                    message = ("Packet burst: "
                                               "delta=%s qsize=%s delay=%s" %
                                               (delta, current, delay))
                                    if delay < 1:
                                        log.debug(message)
                                    else:
                                        log.warning(message)
                                    time.sleep(delay)
                                self.qsize = current

                                # We've got the data, lock the backlog again
                                with self.backlog_lock:
                                    for msg in msgs:
                                        msg['header']['stats'] = Stats(
                                            current, delta, delay)
                                        seq = msg['header']['sequence_number']
                                        if seq not in self.backlog:
                                            if msg['header']['type'] == \
                                                    NLMSG_ERROR:
                                                # Drop orphaned NLMSG_ERROR
                                                # messages
                                                continue
                                            seq = 0
                                        # 8<-----------------------------------
                                        # Callbacks section
                                        for cr in self.callbacks:
                                            try:
                                                if cr[0](msg):
                                                    cr[1](msg, *cr[2])
                                            except Exception:
                                                # A failing callback must
                                                # not break the reader:
                                                # log it and go on. `cr`
                                                # is a tuple, so it must
                                                # be wrapped for `%`.
                                                lw = log.warning
                                                lw("Callback fail: %s" %
                                                   (cr, ))
                                                lw(traceback.format_exc())
                                        # 8<-----------------------------------
                                        self.backlog[seq].append(msg)

                                # Now wake up other threads
                                self.change_master.set()
                            finally:
                                # Finally, release the read lock: all data
                                # processed
                                self.read_lock.release()
                        else:
                            # If the socket is occupied and there is still no
                            # data for us, wait for the next master change or
                            # for a timeout
                            self.change_master.wait(1)
                        # 8<-------------------------------------------------------
                        #
                        # Stage 2. END
                        #
                        # 8<-------------------------------------------------------
            finally:
                if backlog_acquired:
                    self.backlog_lock.release()

    def nlm_request(self,
                    msg,
                    msg_type,
                    msg_flags=NLM_F_REQUEST | NLM_F_DUMP,
                    terminate=None,
                    callback=None):
        '''
        Send `msg` and yield the responses.

        A sequence number is allocated from `self.addr_pool`, the
        request is sent with `put()` and the responses for that
        sequence number are yielded from `get()`. If the kernel
        answers with a NetlinkError of code EBUSY (16), the same
        request is resent up to 30 times, 0.3 s apart. Messages
        already yielded before such an error may be yielded again by
        the retry. Any other error is re-raised.
        '''
        msg_seq = self.addr_pool.alloc()
        with self.lock[msg_seq]:
            try:
                retry_count = 0
                while True:
                    try:
                        # `msg` is the request; the responses get their
                        # own name so that a retry resends the request,
                        # not the last response
                        self.put(msg, msg_type, msg_flags, msg_seq=msg_seq)
                        for response in self.get(msg_seq=msg_seq,
                                                 terminate=terminate,
                                                 callback=callback):
                            yield response
                        break
                    except NetlinkError as e:
                        if e.code != errno.EBUSY:
                            raise
                        if retry_count >= 30:
                            raise
                        log.warning('Error 16, retry {}.'.format(retry_count))
                        time.sleep(0.3)
                        retry_count += 1
            finally:
                # Free msg_seq exactly once, after all the retries.
                # Ban this msg_seq for 0xff rounds
                #
                # It's a long story. Modern kernels for RTM_SET.*
                # operations always return NLMSG_ERROR(0) == success,
                # even not setting NLM_F_MULTI flag on other response
                # messages and thus w/o any NLMSG_DONE. So, how to detect
                # the response end? One can not rely on NLMSG_ERROR on
                # old kernels, but we have to support them too. So, we
                # just ban msg_seq for several rounds, and NLMSG_ERROR,
                # being received, will become orphaned and just dropped.
                #
                # Hack, but true.
                self.addr_pool.free(msg_seq, ban=0xff)
コード例 #8
0
ファイル: nlsocket.py プロジェクト: kinhvan017/pyroute2
class NetlinkMixin(object):
    '''
    Generic netlink socket

    Implements the message level of a netlink socket:

        - encoding and sending requests (`put()`)
        - receiving replies and sorting them by sequence number into
          `self.backlog` (`get()`)
        - request/response round trips (`nlm_request()`)
        - message callbacks and marshalling policies

    Raw I/O goes through `self._sendto()`, `self._recv()` and
    `self.getsockopt()`. Those are not defined in this class and must
    be supplied elsewhere; presumably by the concrete socket class
    (verify against subclasses).
    '''

    def __init__(self,
                 family=NETLINK_GENERIC,
                 port=None,
                 pid=None,
                 fileno=None):
        '''
        Initialize the socket state.

        Parameters:

            - family -- netlink family, NETLINK_GENERIC by default
            - port -- port number; only used when `pid` is None
            - pid -- None: use `os.getpid() & 0x3fffff` and honour
              `port`; 0: use `os.getpid()`; anything else is used as is
            - fileno -- an open file descriptor to build the socket
              from; raises NotImplementedError on Python 2
        '''
        #
        # That's a trick. Python 2 is not able to construct
        # sockets from an open FD.
        #
        # So raise an exception, if the major version is < 3
        # and fileno is not None.
        #
        # Do NOT use fileno in a core pyroute2 functionality,
        # since the core should be both Python 2 and 3
        # compatible.
        #
        super(NetlinkMixin, self).__init__()
        if fileno is not None and sys.version_info[0] < 3:
            raise NotImplementedError('fileno parameter is not supported '
                                      'on Python < 3.2')

        # 8<-----------------------------------------
        # Sequence numbers for requests are allocated from this pool.
        self.addr_pool = AddrPool(minaddr=0x000000ff, maxaddr=0x0000ffff)
        self.epid = None
        self.port = 0
        self.fixed = True
        self.family = family
        self._fileno = fileno
        # msg_seq -> list of received messages; 0 is the queue for
        # broadcasts and messages nobody is waiting for.
        self.backlog = {0: []}
        self.callbacks = []     # [(predicate, callback, args), ...]
        self.pthread = None
        self.closed = False
        self.capabilities = {'create_bridge': True,
                             'create_bond': True,
                             'create_dummy': True,
                             'provide_master': config.kernel[0] > 2}
        self.backlog_lock = threading.Lock()
        # Held by the one thread that currently reads from the socket.
        self.read_lock = threading.Lock()
        # Set by the reader when it has finished a round, to wake up
        # threads waiting for their data.
        self.change_master = threading.Event()
        self.lock = LockFactory()
        self._sock = None
        # Control pipe: an event on the read end stops `async_recv()`.
        self._ctrl_read, self._ctrl_write = os.pipe()
        self.buffer_queue = Queue()
        self.qsize = 0
        self.log = []
        self.get_timeout = 30
        self.get_timeout_exception = None
        if pid is None:
            self.pid = os.getpid() & 0x3fffff
            self.port = port
            self.fixed = self.port is not None
        elif pid == 0:
            self.pid = os.getpid()
        else:
            self.pid = pid
        # 8<-----------------------------------------
        self.groups = 0
        self.marshal = Marshal()
        # 8<-----------------------------------------
        # Set defaults
        self.post_init()

    def clone(self):
        '''
        Return a new socket object of the same type and family.
        '''
        return type(self)(family=self.family)

    def close(self):
        '''
        Close both ends of the control pipe.

        Errors from an already closed descriptor are ignored. The
        netlink socket itself is not touched by this method.
        '''
        try:
            os.close(self._ctrl_write)
            os.close(self._ctrl_read)
        except OSError:
            # ignore the case when it is closed already
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def release(self):
        '''
        Deprecated alias for `close()`; logs a warning first.
        '''
        log.warning("The `release()` call is deprecated")
        log.warning("Use `close()` instead")
        self.close()

    def register_callback(self, callback,
                          predicate=lambda x: True, args=None):
        '''
        Register a callback to run on a message arrival.

        Callback is the function that will be called with the
        message as the first argument. Predicate is the optional
        callable object, that returns True or False. Upon True,
        the callback will be called. Upon False it will not.
        Args is a list or tuple of arguments.

        Simplest example, assume ipr is the IPRoute() instance::

            # create a simplest callback that will print messages
            def cb(msg):
                print(msg)

            # register callback for any message:
            ipr.register_callback(cb)

        More complex example, with filtering::

            # Set object's attribute after the message key
            def cb(msg, obj):
                obj.some_attr = msg["some key"]

            # Register the callback only for the loopback device, index 1:
            ipr.register_callback(cb,
                                  lambda x: x.get('index', None) == 1,
                                  (self, ))

        Please note: you do **not** need to register the default 0 queue
        to invoke callbacks on broadcast messages. Callbacks are
        iterated **before** messages get enqueued.
        '''
        if args is None:
            args = []
        self.callbacks.append((predicate, callback, args))

    def unregister_callback(self, callback):
        '''
        Remove the first reference to the function from the callback
        register
        '''
        # Walk a snapshot; the list is modified only right before
        # returning.
        for index, record in enumerate(tuple(self.callbacks)):
            if record[1] == callback:
                self.callbacks.pop(index)
                return

    def register_policy(self, policy, msg_class=None):
        '''
        Register netlink encoding/decoding policy. Can
        be specified in two ways:
        `nlsocket.register_policy(MSG_ID, msg_class)`
        to register one particular rule, or
        `nlsocket.register_policy({MSG_ID1: msg_class})`
        to register several rules at once.
        E.g.::

            policy = {RTM_NEWLINK: ifinfmsg,
                      RTM_DELLINK: ifinfmsg,
                      RTM_NEWADDR: ifaddrmsg,
                      RTM_DELADDR: ifaddrmsg}
            nlsocket.register_policy(policy)

        One can call `register_policy()` as many times,
        as one want to -- it will just extend the current
        policy scheme, not replace it.
        '''
        if isinstance(policy, int) and msg_class is not None:
            policy = {policy: msg_class}

        assert isinstance(policy, dict)
        for key in policy:
            self.marshal.msg_map[key] = policy[key]

        return self.marshal.msg_map

    def unregister_policy(self, policy):
        '''
        Unregister policy. Policy can be:

            - int -- then it will just remove one policy
            - list or tuple of ints -- remove all given
            - dict -- remove policies by keys from dict

        In the last case the routine will ignore dict values,
        it is implemented so just to make it compatible with
        `get_policy_map()` return value.

        Raises KeyError if a given message type is not registered.
        '''
        if isinstance(policy, int):
            policy = [policy]
        elif isinstance(policy, dict):
            policy = list(policy)

        assert isinstance(policy, (tuple, list, set))

        for key in policy:
            del self.marshal.msg_map[key]

        return self.marshal.msg_map

    def get_policy_map(self, policy=None):
        '''
        Return policy for a given message type or for all
        message types. Policy parameter can be either int,
        or a list of ints. Always return dictionary.

        Without arguments the live `self.marshal.msg_map` is
        returned, not a copy.
        '''
        if policy is None:
            return self.marshal.msg_map

        if isinstance(policy, int):
            policy = [policy]

        assert isinstance(policy, (list, tuple, set))

        ret = {}
        for key in policy:
            ret[key] = self.marshal.msg_map[key]

        return ret

    def sendto(self, *argv, **kwarg):
        '''
        Thin wrapper around `self._sendto()`, which must be provided
        elsewhere.
        '''
        return self._sendto(*argv, **kwarg)

    def recv(self, *argv, **kwarg):
        '''
        Thin wrapper around `self._recv()`, which must be provided
        elsewhere.
        '''
        return self._recv(*argv, **kwarg)

    def async_recv(self):
        '''
        Move raw data from the netlink socket into `self.buffer_queue`.

        Polls the socket and `self._ctrl_read`. On every event on the
        socket, the result of one `recv(1024 * 1024)` call is enqueued.
        If that call raises, the exception object is enqueued instead.
        The first event on the other descriptor ends the loop.
        '''
        poll = select.poll()
        poll.register(self._sock, select.POLLIN | select.POLLPRI)
        poll.register(self._ctrl_read, select.POLLIN | select.POLLPRI)
        sockfd = self._sock.fileno()
        while True:
            events = poll.poll()
            for (fd, event) in events:
                if fd == sockfd:
                    try:
                        self.buffer_queue.put(self._sock.recv(1024 * 1024))
                    except Exception as e:
                        self.buffer_queue.put(e)
                else:
                    return

    def put(self, msg, msg_type,
            msg_flags=NLM_F_REQUEST,
            addr=(0, 0),
            msg_seq=0,
            msg_pid=None):
        '''
        Construct a message from a dictionary and send it to
        the socket. Parameters:

            - msg -- the message in the dictionary format
            - msg_type -- the message type
            - msg_flags -- the message flags to use in the request
            - addr -- `sendto()` addr, default `(0, 0)`
            - msg_seq -- sequence number to use
            - msg_pid -- pid to use, if `None` -- use `self.epid`,
              or `os.getpid()` if `self.epid` is not set

        Example::

            s = IPRSocket()
            s.bind()
            s.put({'index': 1}, RTM_GETLINK)
            s.get()
            s.close()

        Please notice, that the return value of `s.get()` can be
        not the result of `s.put()`, but any broadcast message.
        To fix that, use `msg_seq` -- the response must contain the
        same `msg['header']['sequence_number']` value.
        '''
        if msg_seq != 0:
            self.lock[msg_seq].acquire()
        try:
            # Create the backlog queue first, so that replies
            # received before `get()` is called are not dropped.
            if msg_seq not in self.backlog:
                self.backlog[msg_seq] = []
            if not isinstance(msg, nlmsg):
                msg_class = self.marshal.msg_map[msg_type]
                msg = msg_class(msg)
            if msg_pid is None:
                msg_pid = self.epid or os.getpid()
            msg['header']['type'] = msg_type
            msg['header']['flags'] = msg_flags
            msg['header']['sequence_number'] = msg_seq
            msg['header']['pid'] = msg_pid
            self.sendto_gate(msg, addr)
        finally:
            if msg_seq != 0:
                self.lock[msg_seq].release()

    def sendto_gate(self, msg, addr):
        '''
        Encode `msg` and send the resulting bytes to `addr`.
        '''
        msg.encode()
        self.sendto(msg.buf.getvalue(), addr)

    def get(self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None):
        '''
        Get parsed messages list. If `msg_seq` is given, return
        only messages with that `msg['header']['sequence_number']`,
        saving all other messages into `self.backlog`.

        The routine is thread-safe.

        The `bufsize` parameter can be:

            - -1: bufsize will be calculated from the first 4 bytes of
                the network data
            - 0: bufsize will be calculated from SO_RCVBUF sockopt
            - int > 0: just a bufsize

        A response for a non-zero `msg_seq` ends with any of:

            - NLMSG_DONE
            - a message for which `terminate(msg)` is true
            - a message without NLM_F_MULTI

        If a collected message carries `msg['header']['error']`, that
        error is raised.

        If no data comes within `self.get_timeout` seconds,
        `self.get_timeout_exception()` is raised when it is set.
        Otherwise the messages collected so far are returned.
        '''
        ctime = time.time()

        with self.lock[msg_seq]:
            if bufsize == -1:
                # get bufsize from the network data
                bufsize = struct.unpack("I", self.recv(4, MSG_PEEK))[0]
            elif bufsize == 0:
                # get bufsize from SO_RCVBUF
                bufsize = self.getsockopt(SOL_SOCKET, SO_RCVBUF) // 2

            ret = []
            enough = False
            while not enough:
                # 8<-----------------------------------------------------------
                #
                # Stage 1. Check the backlog and return already collected
                # messages.
                #
                # This stage changes the backlog, so it runs under
                # `backlog_lock`. The `with` statement releases the lock on
                # every exit path: `break`, `continue`, the raised netlink
                # error and any exception from `terminate()`.
                #
                with self.backlog_lock:
                    if msg_seq == 0 and self.backlog[0]:
                        # Zero queue: take everything collected so far
                        # and exit.
                        ret.extend(self.backlog[0])
                        self.backlog[0] = []
                        break

                    if self.backlog.get(msg_seq, None):
                        # Any other msg_seq.
                        #
                        # Collect messages up to the terminator.
                        # Terminator conditions:
                        #  * NLMSG_ERROR != 0
                        #  * NLMSG_DONE
                        #  * terminate() function (if defined)
                        #  * not NLM_F_MULTI
                        #
                        # Please note, that if no terminator occurred,
                        # more `recv()` rounds CAN be required.
                        for msg in tuple(self.backlog[msg_seq]):

                            # Drop the message from the backlog
                            self.backlog[msg_seq].remove(msg)

                            # If there is an error, requeue the rest into
                            # the zero queue and raise it
                            if msg['header'].get('error', None) is not None:
                                self.backlog[0].extend(self.backlog[msg_seq])
                                del self.backlog[msg_seq]
                                raise msg['header']['error']

                            # If it is the terminator message, say "enough"
                            if (msg['header']['type'] == NLMSG_DONE) or \
                                    (terminate is not None and terminate(msg)):
                                enough = True

                            # If it is just a normal message, append it to
                            # the response
                            if not enough:
                                ret.append(msg)
                                # But finish the loop on single messages
                                if not msg['header']['flags'] & NLM_F_MULTI:
                                    enough = True

                            # Enough is enough, requeue the rest into the
                            # zero queue and delete our backlog
                            if enough:
                                self.backlog[0].extend(self.backlog[msg_seq])
                                del self.backlog[msg_seq]
                                break

                        # Next iteration
                        continue

                # Stage 1. END
                #
                # 8<-----------------------------------------------------------
                #
                # Stage 2. Receive the data from the socket and put the
                # messages into the backlog. No locks except
                # `self.lock[msg_seq]` are held at this point.
                #
                # Control the timeout. We should not be within the
                # function more than TIMEOUT seconds.
                #
                if time.time() - ctime > self.get_timeout:
                    if self.get_timeout_exception:
                        raise self.get_timeout_exception()
                    else:
                        return ret

                if self.read_lock.acquire(False):
                    # The socket is free to read from: occupy it and wait
                    # for the data. The `finally` guarantees that a failing
                    # recv()/parse() does not leave the read lock held
                    # forever and wakes up the waiting threads.
                    try:
                        self.change_master.clear()
                        # This is a time consuming process, so all the
                        # locks, except the read lock, are released
                        data = self.recv(bufsize)
                        # Parse data
                        msgs = self.marshal.parse(data)
                        # Reset ctime -- timeout should be measured
                        # for every turn separately
                        ctime = time.time()
                        #
                        current = self.buffer_queue.qsize()
                        delta = current - self.qsize
                        if delta > 10:
                            delay = min(3, max(0.01, float(current) / 60000))
                            message = ("Packet burst: the reader thread "
                                       "priority is increased, beware of "
                                       "delays on netlink calls\n\tCounters: "
                                       "delta=%s qsize=%s delay=%s "
                                       % (delta, current, delay))
                            if delay < 1:
                                log.debug(message)
                            else:
                                log.warning(message)
                            time.sleep(delay)
                        self.qsize = current

                        # We've got the data, lock the backlog again
                        with self.backlog_lock:
                            for msg in msgs:
                                seq = msg['header']['sequence_number']
                                if seq not in self.backlog:
                                    if msg['header']['type'] == NLMSG_ERROR:
                                        # Drop orphaned NLMSG_ERROR messages
                                        continue
                                    seq = 0
                                # 8<---------------------------------------
                                # Callbacks section. A failing callback is
                                # logged and must not break the reader.
                                for cr in self.callbacks:
                                    try:
                                        if cr[0](msg):
                                            cr[1](msg, *cr[2])
                                    except Exception:
                                        # Pass `cr` as a logging argument:
                                        # `"%s" % cr` with a 3-tuple raises
                                        # TypeError.
                                        log.warning("Callback fail: %s", cr)
                                        log.warning(traceback.format_exc())
                                # 8<---------------------------------------
                                self.backlog[seq].append(msg)
                    finally:
                        # Now wake up other threads
                        self.change_master.set()
                        # Finally, release the read lock
                        self.read_lock.release()
                else:
                    # If the socket is occupied and there is still no
                    # data for us, wait for the next master change or
                    # for a timeout
                    self.change_master.wait(1)
                # 8<-----------------------------------------------------------
                #
                # Stage 2. END
                #
                # 8<-----------------------------------------------------------

            return ret

    def nlm_request(self, msg, msg_type,
                    msg_flags=NLM_F_REQUEST | NLM_F_DUMP,
                    terminate=None,
                    exception_catch=Exception,
                    exception_handler=None):
        '''
        Send a request and return the list of response messages.

        Each attempt allocates a fresh sequence number and calls
        `msg.reset()`, `self.put()` and `self.get()`.

        Exceptions matching `exception_catch` are passed to
        `exception_handler(e)`:

            - if the handler returns a false value, the request is
              retried, with no limit on the number of attempts
            - if it returns a true value, or there is no handler, the
              exception is re-raised
        '''

        def do_try():
            msg_seq = self.addr_pool.alloc()
            with self.lock[msg_seq]:
                try:
                    msg.reset()
                    self.put(msg, msg_type, msg_flags, msg_seq=msg_seq)
                    return self.get(msg_seq=msg_seq, terminate=terminate)
                finally:
                    # Ban this msg_seq for 0xff rounds.
                    #
                    # It's a long story. Modern kernels for RTM_SET.*
                    # operations always return NLMSG_ERROR(0) == success,
                    # even without setting the NLM_F_MULTI flag on other
                    # response messages, and thus w/o any NLMSG_DONE. So
                    # how do we detect the end of the response? One can not
                    # rely on NLMSG_ERROR on old kernels, but we have to
                    # support them too. So we just ban msg_seq for several
                    # rounds, and an NLMSG_ERROR received later becomes
                    # orphaned and is just dropped.
                    #
                    # Hack, but true.
                    self.addr_pool.free(msg_seq, ban=0xff)

        while True:
            try:
                return do_try()
            except exception_catch as e:
                if exception_handler and not exception_handler(e):
                    continue
                raise
コード例 #9
0
ファイル: test_addrpool.py プロジェクト: 0x90/pyroute2
    def test_free(self):
        # Take one address out of a 1..1024 pool and hand it straight
        # back; neither call is expected to raise.
        pool = AddrPool(minaddr=1, maxaddr=1024)
        pool.free(pool.alloc())
コード例 #10
0
ファイル: test_addrpool.py プロジェクト: welterde/pyroute2
    def test_free(self):
        # Allocate one address from a pool spanning 1..1024 and
        # release it again.

        ap = AddrPool(minaddr=1, maxaddr=1024)
        f = ap.alloc()
        ap.free(f)