Ejemplo n.º 1
0
class RTNL_Object(dict):

    view = None  # (optional) view to load values for the summary etc.
    utable = None  # table to send updates to

    key_extra_fields = []
    hidden_fields = []
    fields_cmp = {}

    schema = None
    event_map = None
    state = None
    log = None
    errors = None
    msg_class = None
    reverse_update = None
    _table = None
    _apply_script = None
    _key = None
    _replace = None
    _replace_on_key_change = False

    # 8<------------------------------------------------------------
    #
    # Documented public properties section
    #
    @property
    def table(self):
        '''
        The main reference table for the object. The SQL schema of the
        table is used to build the key and to verify the fields.
        '''
        return self._table

    @table.setter
    def table(self, value):
        self._table = value

    @property
    def etable(self):
        '''
        The table where the object actually fetches the data from. It
        is not always equal `self.table`, e.g. snapshot objects fetch
        the data from snapshot tables.

        Read-only property.
        '''
        if self.ctxid:
            return '%s_%s' % (self.table, self.ctxid)
        else:
            return self.table

    @property
    def key(self):
        '''
        The key of the object, used to build SQL requests to fetch
        the data from the DB.
        '''
        nkey = self._key or {}
        ret = collections.OrderedDict()
        for name in self.kspec:
            kname = self.iclass.nla2name(name)
            if kname in self:
                value = self[kname]
                if value is None and name in nkey:
                    value = nkey[name]
                if isinstance(value, (list, tuple, dict)):
                    value = json.dumps(value)
                ret[name] = value
        if len(ret) < len(self.kspec):
            for name in self.key_extra_fields:
                kname = self.iclass.nla2name(name)
                if self.get(kname):
                    ret[name] = self[kname]
        return ret

    @key.setter
    def key(self, k):
        if not isinstance(k, dict):
            return
        for key, value in k.items():
            if value is not None:
                dict.__setitem__(self, self.iclass.nla2name(key), value)

    #
    # 8<------------------------------------------------------------
    #

    @classmethod
    def _dump_where(cls, view):
        return '', []

    @classmethod
    def _sdump(cls, view, names, fnames):
        req = '''
              SELECT %s FROM %s AS main
              ''' % (fnames, cls.table)
        yield names
        where, values = cls._dump_where(view)
        for record in view.ndb.schema.fetch(req + where, values):
            yield record

    @classmethod
    def summary(cls, view):
        return cls._sdump(view,
                          view.ndb.schema.compiled[cls.table]['norm_idx'],
                          view.ndb.schema.compiled[cls.table]['knames'])

    @classmethod
    def dump(cls, view):
        return cls._sdump(view,
                          view.ndb.schema.compiled[cls.table]['norm_names'],
                          view.ndb.schema.compiled[cls.table]['fnames'])

    def __init__(self, view, key, iclass, ctxid=None, load=True, master=None):
        self.view = view
        self.ndb = view.ndb
        self.sources = view.ndb.sources
        self.master = master
        self.ctxid = ctxid
        self.schema = view.ndb.schema
        self.changed = set()
        self.iclass = iclass
        self.utable = self.utable or self.table
        self.errors = []
        self.atime = time.time()
        self.log = self.ndb.log.channel('rtnl_object')
        self.log.debug('init')
        self.state = State()
        self.state.set('invalid')
        self.snapshot_deps = []
        self.load_event = threading.Event()
        self.load_event.set()
        self.lock = threading.Lock()
        self.kspec = self.schema.compiled[self.table]['idx']
        self.knorm = self.schema.compiled[self.table]['norm_idx']
        self.spec = self.schema.compiled[self.table]['all_names']
        self.names = self.schema.compiled[self.table]['norm_names']
        self._apply_script = []
        if isinstance(key, dict):
            self.chain = key.pop('ndb_chain', None)
            create = key.pop('create', False)
        else:
            self.chain = None
            create = False
        ckey = self.complete_key(key)
        if create:
            if ckey is not None:
                raise KeyError('object exists')
            for name in key:
                self[name] = key[name]
            # FIXME -- merge with complete_key()
            if 'target' not in self:
                self.load_value('target', self.view.default_target)
        else:
            if ckey is None:
                raise KeyError('object does not exists')
            self.key = ckey
            if load:
                if ctxid is None:
                    self.load_sql()
                else:
                    self.load_sql(table=self.table)

    def mark_tflags(self, mark):
        pass

    def keys(self):
        return filter(lambda x: x not in self.hidden_fields, dict.keys(self))

    def items(self):
        return filter(lambda x: x[0] not in self.hidden_fields,
                      dict.items(self))

    @property
    def context(self):
        return {'target': self.get('target', 'localhost')}

    @classmethod
    def adjust_spec(cls, spec, context):
        ret = dict(spec)
        if context is None:
            context = {}
        ret.update(context)
        return ret

    @classmethod
    def nla2name(self, name):
        return self.msg_class.nla2name(name)

    @classmethod
    def name2nla(self, name):
        return self.msg_class.name2nla(name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.commit()

    def __hash__(self):
        return id(self)

    def __setitem__(self, key, value):
        if self.state == 'system' and key in self.knorm:
            if self._replace_on_key_change:
                self.log.debug('prepare replace for key %s' % (self.key))
                self._replace = type(self)(self.view, self.key)
                self.state.set('replace')
            else:
                raise ValueError('attempt to change a key field (%s)' % key)
        if key in ('net_ns_fd', 'net_ns_pid'):
            self.state.set('setns')
        if value != self.get(key, None):
            if key != 'target':
                self.changed.add(key)
            dict.__setitem__(self, key, value)

    def fields(self, *argv):
        Fields = collections.namedtuple('Fields', argv)
        return Fields(*[self[key] for key in argv])

    def key_repr(self):
        return repr(self.key)

    @cli.change_pointer
    def create(self, **spec):
        spec['create'] = True
        spec['ndb_chain'] = self
        return self.view[spec]

    @cli.show_result
    def show(self, **kwarg):
        '''
        Return the object in a specified format. The format may be
        specified with the keyword argument `format` or in the
        `ndb.config['show_format']`.

        TODO: document different formats
        '''
        fmt = kwarg.pop(
            'format',
            kwarg.pop('fmt', self.view.ndb.config.get('show_format',
                                                      'native')))
        if fmt == 'native':
            return dict(self)
        else:
            out = collections.OrderedDict()
            for key in sorted(self):
                out[key] = self[key]
            return '%s\n' % json.dumps(out, indent=4, separators=(',', ': '))

    def set(self, key, value):
        '''
        Set a field specified by `key` to `value`, and return self. The
        method is useful to write call chains like that::

            (ndb
             .interfaces["eth0"]
             .set('mtu', 1200)
             .set('state', 'up')
             .set('address', '00:11:22:33:44:55')
             .commit())
        '''
        self[key] = value
        return self

    def wtime(self, itn=1):
        return max(min(itn * 0.1, 1), self.view.ndb._event_queue.qsize() / 10)

    def register(self):
        #
        # Construct a weakref handler for events.
        #
        # If the referent doesn't exist, raise the
        # exception to remove the handler from the
        # chain.
        #
        def wr_handler(wr, fname, *argv):
            try:
                return getattr(wr(), fname)(*argv)
            except:
                # check if the weakref became invalid
                if wr() is None:
                    raise InvalidateHandlerException()
                raise

        wr = weakref.ref(self)
        for event, fname in self.event_map.items():
            #
            # Do not trust the implicit scope and pass the
            # weakref explicitly via partial
            #
            (self.ndb.register_handler(event, partial(wr_handler, wr, fname)))

    def snapshot(self, ctxid=None):
        '''
        Create and return a snapshot of the object. The method creates
        corresponding SQL tables for the object itself and for detected
        dependencies.

        The snapshot tables will be removed as soon as the snapshot gets
        collected by the GC.
        '''
        ctxid = ctxid or self.ctxid or id(self)
        if self._replace is None:
            key = self.key
        else:
            key = self._replace.key
        snp = type(self)(self.view, key, ctxid=ctxid)
        self.ndb.schema.save_deps(ctxid, weakref.ref(snp), self.iclass)
        snp.changed = set(self.changed)
        return snp

    def complete_key(self, key):
        '''
        Try to complete the object key based on the provided fields.
        E.g.::

            >>> ndb.interfaces['eth0'].complete_key({"ifname": "eth0"})
            {'ifname': 'eth0',
             'index': 2,
             'target': u'localhost',
             'tflags': 0}

        It is an internal method and is not supposed to be used externally.
        '''
        self.log.debug('complete key %s from table %s' % (key, self.etable))
        fetch = []
        if isinstance(key, Record):
            key = key._as_dict()
        for name in self.kspec:
            if name not in key:
                fetch.append('f_%s' % name)

        if fetch:
            keys = []
            values = []
            for name, value in key.items():
                nla_name = self.iclass.name2nla(name)
                if nla_name in self.spec:
                    name = nla_name
                if value is not None and name in self.spec:
                    keys.append('f_%s = %s' % (name, self.schema.plch))
                    values.append(value)
            spec = (self.ndb.schema.fetchone(
                'SELECT %s FROM %s WHERE %s' %
                (' , '.join(fetch), self.etable, ' AND '.join(keys)), values))
            if spec is None:
                self.log.debug('got none')
                return None
            for name, value in zip(fetch, spec):
                key[name[2:]] = value

        self.log.debug('got %s' % key)
        return key

    def rollback(self, snapshot=None):
        '''
        Try to rollback the object state using the snapshot provided as
        an argument or using `self.last_save`.
        '''
        if self._replace is not None:
            self.log.debug('rollback replace: %s :: %s' %
                           (self.key, self._replace.key))
            new_replace = type(self)(self.view, self.key)
            new_replace.state.set('remove')
            self.state.set('replace')
            self.update(self._replace)
            self._replace = new_replace
        self.log.debug('rollback: %s' % str(self.state.events))
        snapshot = snapshot or self.last_save
        snapshot.state.set(self.state.get())
        snapshot.apply(rollback=True)
        for link, snp in snapshot.snapshot_deps:
            link.rollback(snapshot=snp)
        return self

    def clear(self):
        pass

    @property
    def clean(self):
        return self.state == 'system' and \
            not self.changed and \
            not self._apply_script

    def commit(self):
        '''
        Try to commit the pending changes. If the commit fails,
        automatically revert the state.
        '''
        if self.clean:
            return self

        if self.chain:
            self.chain.commit()
        self.log.debug('commit: %s' % str(self.state.events))
        # Is it a new object?
        if self.state == 'invalid':
            # Save values, try to apply
            save = dict(self)
            try:
                return self.apply()
            except Exception as e_i:
                # Save the debug info
                e_i.trace = traceback.format_exc()
                # ACHTUNG! The routine doesn't clean up the system
                #
                # Drop all the values and rollback to the initial state
                for key in tuple(self.keys()):
                    del self[key]
                for key in save:
                    dict.__setitem__(self, key, save[key])
                raise e_i

        # Continue with an existing object

        # The snapshot tables in the DB will be dropped as soon as the GC
        # collects the object. But in the case of an exception the `snp`
        # variable will be saved in the traceback, so the tables will be
        # available to debug. If the traceback will be saved somewhere then
        # the tables will never be dropped by the GC, so you can do it
        # manually by `ndb.schema.purge_snapshots()` -- to invalidate all
        # the snapshots and to drop the associated tables.

        # Apply the changes
        try:
            self.apply()
        except Exception as e_c:
            # Rollback in the case of any error
            try:
                self.rollback()
            except Exception as e_r:
                e_c.chain = [e_r]
                if hasattr(e_r, 'chain'):
                    e_c.chain.extend(e_r.chain)
                e_r.chain = None
            raise
        finally:
            (self.last_save.state.set(self.state.get()))
        if self._replace is not None:
            self._replace = None
        return self

    def remove(self):
        with self.lock:
            self.state.set('remove')
            return self

    def check(self):
        state_map = (('invalid', 'system'), ('remove', 'invalid'),
                     ('setns', 'invalid'), ('setns', 'system'), ('replace',
                                                                 'system'))

        self.load_sql()
        self.log.debug('check: %s' % str(self.state.events))

        if self.state.transition() not in state_map:
            self.log.debug('check state: False')
            return False

        if self.changed:
            self.log.debug('check changed: %s' % (self.changed))
            return False

        self.log.debug('check: True')
        return True

    def make_req(self, prime):
        req = dict(prime)
        for key in self.changed:
            req[key] = self[key]
        return req

    def make_idx_req(self, prime):
        return prime

    def get_count(self):
        conditions = []
        values = []
        for name in self.kspec:
            conditions.append('f_%s = %s' % (name, self.schema.plch))
            values.append(self.get(self.iclass.nla2name(name), None))
        return (self.ndb.schema.fetchone(
            '''
                          SELECT count(*) FROM %s WHERE %s
                          ''' % (self.table, ' AND '.join(conditions)),
            values))[0]

    def hook_apply(self, method, **spec):
        pass

    def apply(self, rollback=False):
        '''
        Create a snapshot and apply pending changes. Do not revert
        the changes in the case of an exception.
        '''

        # Save the context
        if not rollback and self.state != 'invalid':
            self.last_save = self.snapshot()

        self.log.debug('apply: %s' % str(self.state.events))
        self.load_event.clear()
        self._apply_script_snapshots = []

        # Load the current state
        try:
            self.schema.commit()
        except:
            pass
        self.load_sql(set_state=False)
        if self.state == 'system' and self.get_count() == 0:
            state = self.state.set('invalid')
        else:
            state = self.state.get()

        # Create the request.
        idx_req = dict([(x, self[self.iclass.nla2name(x)])
                        for x in self.schema.compiled[self.table]['idx']
                        if self.iclass.nla2name(x) in self])
        req = self.make_req(idx_req)
        idx_req = self.make_idx_req(idx_req)
        self.log.debug('apply req: %s' % str(req))
        self.log.debug('apply idx_req: %s' % str(idx_req))
        self.log.debug('apply state: %s' % state)

        method = None
        ignore = tuple()
        #
        if state in ('invalid', 'replace'):
            for k, v in tuple(self.items()):
                if k not in req and v is not None:
                    req[k] = v
            if self.master is not None:
                req = self.adjust_spec(req, self.master.context)
            method = 'add'
            ignore = {errno.EEXIST: 'set'}
        elif state == 'system':
            method = 'set'
        elif state == 'setns':
            method = 'set'
            ignore = {errno.ENODEV: None}
        elif state == 'remove':
            method = 'del'
            req = idx_req
            ignore = {
                errno.ENODEV: None,  # interfaces
                errno.ESRCH: None,  # routes
                errno.EADDRNOTAVAIL: None
            }  # addresses
        else:
            raise Exception('state transition not supported')

        for itn in range(20):
            try:
                self.log.debug('run %s (%s)' % (method, req))
                (self.sources[self['target']].api(self.api, method, **req))
                (self.hook_apply(method, **req))
            except NetlinkError as e:
                (self.log.debug('error: %s' % e))
                if e.code in ignore:
                    self.log.debug('ignore error %s for %s' % (e.code, self))
                    if ignore[e.code] is not None:
                        self.log.debug('run fallback %s (%s)' %
                                       (ignore[e.code], req))
                        try:
                            (self.sources[self['target']].api(
                                self.api, ignore[e.code], **req))
                        except NetlinkError:
                            pass
                else:
                    raise e

            wtime = self.wtime(itn)
            mqsize = self.view.ndb._event_queue.qsize()
            nq = self.schema.stats.get(self['target'])
            if nq is not None:
                nqsize = nq.qsize
            else:
                nqsize = 0
            self.log.debug('stats: apply %s {'
                           'objid %s, wtime %s, '
                           'mqsize %s, nqsize %s'
                           '}' % (method, id(self), wtime, mqsize, nqsize))
            if self.check():
                self.log.debug('checked')
                break
            self.log.debug('check failed')
            self.load_event.wait(wtime)
            self.load_event.clear()
        else:
            self.log.debug('stats: %s apply %s fail' % (id(self), method))
            raise Exception('lost sync in apply()')

        self.log.debug('stats: %s pass' % (id(self)))
        #
        if state == 'replace':
            self._replace.remove()
            self._replace.apply()
        #
        if rollback:
            #
            # Iterate all the snapshot tables and collect the diff
            for cls in self.view.classes.values():
                if issubclass(type(self), cls) or \
                        issubclass(cls, type(self)):
                    continue
                table = cls.table
                # comprare the tables
                diff = (self.ndb.schema.fetch('''
                               SELECT * FROM %s_%s
                                 EXCEPT
                               SELECT * FROM %s
                               ''' % (table, self.ctxid, table)))
                for record in diff:
                    record = dict(
                        zip((self.schema.compiled[table]['all_names']),
                            record))
                    key = dict([
                        x for x in record.items()
                        if x[0] in self.schema.compiled[table]['idx']
                    ])
                    obj = self.view.get(key, table)
                    obj.load_sql(ctxid=self.ctxid)
                    obj.state.set('invalid')
                    try:
                        obj.apply()
                    except Exception as e:
                        self.errors.append((time.time(), obj, e))
        else:
            for op, argv, kwarg in self._apply_script:
                ret = op(*argv, **kwarg)
                if isinstance(ret, Exception):
                    raise ret
                elif ret is not None:
                    self._apply_script_snapshots.append(ret)
            self._apply_script = []
        return self

    def update(self, data):
        for key, value in data.items():
            self.load_value(key, value)

    def load_value(self, key, value):
        '''
        Load a value and clean up the `self.changed` set if the
        loaded value matches the expectation.
        '''
        if key not in self.changed:
            dict.__setitem__(self, key, value)
        elif self.get(key) == value:
            self.changed.remove(key)
        elif key in self.fields_cmp and self.fields_cmp[key](self, value):
            self.changed.remove(key)

    def load_sql(self, table=None, ctxid=None, set_state=True):
        '''
        Load the data from the database.
        '''

        if not self.key:
            return

        if table is None:
            if ctxid is None:
                table = self.etable
            else:
                table = '%s_%s' % (self.table, ctxid)
        keys = []
        values = []

        for name, value in self.key.items():
            keys.append('f_%s = %s' % (name, self.schema.plch))
            if isinstance(value, (list, tuple, dict)):
                value = json.dumps(value)
            values.append(value)

        spec = (self.ndb.schema.fetchone(
            'SELECT * FROM %s WHERE %s' % (table, ' AND '.join(keys)), values))
        self.log.debug('load_sql: %s' % str(spec))
        if set_state:
            with self.lock:
                if spec is None:
                    if self.state != 'invalid':
                        # No such object (anymore)
                        self.state.set('invalid')
                        self.changed = set()
                elif self.state not in ('remove', 'setns'):
                    self.update(dict(zip(self.names, spec)))
                    self.state.set('system')
        return spec

    def load_rtnlmsg(self, target, event):
        '''
        Check if the RTNL event matches the object and load the
        data from the database if it does.
        '''
        # TODO: partial match (object rename / restore)
        # ...

        # full match
        for name in self.knorm:
            value = self.get(name)
            if name == 'target':
                if value != target:
                    return
            elif name == 'tflags':
                continue
            elif value != (event.get_attr(name) or event.get(name)):
                return

        self.log.debug('load_rtnl: %s' % str(event.get('header')))
        if event['header'].get('type', 0) % 2:
            self.state.set('invalid')
            self.changed = set()
        else:
            self.load_sql()
        self.load_event.set()
Ejemplo n.º 2
0
class Source(dict):
    '''
    The RNTL source. The source that is used to init the object
    must comply to IPRoute API, must support the async_cache. If
    the source starts additional threads, they must be joined
    in the source.close()
    '''
    table_alias = 'src'
    dump = None
    dump_header = None
    summary = None
    summary_header = None
    view = None
    table = 'sources'
    vmap = {'local': IPRoute, 'netns': NetNS, 'remote': RemoteIPRoute}

    def __init__(self, ndb, **spec):
        self.th = None
        self.nl = None
        self.ndb = ndb
        self.evq = self.ndb._event_queue
        # the target id -- just in case
        self.target = spec.pop('target')
        kind = spec.pop('kind', 'local')
        self.persistent = spec.pop('persistent', True)
        self.event = spec.pop('event')
        if not self.event:
            self.event = SyncStart()
        # RTNL API
        self.nl_prime = self.vmap[kind]
        self.nl_kwarg = spec
        #
        self.shutdown = threading.Event()
        self.started = threading.Event()
        self.lock = threading.RLock()
        self.started.clear()
        self.state = State()
        self.log = Log()
        self.state.set('init')
        self.ndb.schema.execute(
            '''
                                INSERT INTO sources (f_target, f_kind)
                                VALUES (%s, %s)
                                ''' %
            (self.ndb.schema.plch, self.ndb.schema.plch), (self.target, kind))
        for key, value in spec.items():
            vtype = 'int' if isinstance(value, int) else 'str'
            self.ndb.schema.execute(
                '''
                                    INSERT INTO options (f_target,
                                                         f_name,
                                                         f_type,
                                                         f_value)
                                    VALUES (%s, %s, %s, %s)
                                    ''' %
                (self.ndb.schema.plch, self.ndb.schema.plch,
                 self.ndb.schema.plch, self.ndb.schema.plch),
                (self.target, key, vtype, value))

        self.load_sql()

    @classmethod
    def defaults(cls, spec):
        ret = dict(spec)
        defaults = {}
        if 'hostname' in spec:
            defaults['kind'] = 'remote'
            defaults['protocol'] = 'ssh'
            defaults['target'] = spec['hostname']
        if 'netns' in spec:
            defaults['kind'] = 'netns'
            defaults['target'] = spec['netns']
        for key in defaults:
            if key not in ret:
                ret[key] = defaults[key]
        return ret

    def remove(self):
        with self.lock:
            self.close()
            (self.ndb.schema.execute(
                '''
                      DELETE FROM sources WHERE f_target=%s
                      ''' % self.ndb.schema.plch, (self.target, )))
            (self.ndb.schema.execute(
                '''
                      DELETE FROM options WHERE f_target=%s
                      ''' % self.ndb.schema.plch, (self.target, )))
            return self

    def __repr__(self):
        if isinstance(self.nl_prime, NetlinkMixin):
            name = self.nl_prime.__class__.__name__
        elif isinstance(self.nl_prime, type):
            name = self.nl_prime.__name__

        return '[%s] <%s %s>' % (self.state.get(), name, self.nl_kwarg)

    @classmethod
    def nla2name(cls, name):
        return name

    @classmethod
    def name2nla(cls, name):
        return name

    def api(self, name, *argv, **kwarg):
        for _ in range(100):  # FIXME make a constant
            with self.lock:
                try:
                    return getattr(self.nl, name)(*argv, **kwarg)
                except (NetlinkError, AttributeError, ValueError, KeyError):
                    raise
                except Exception as e:
                    # probably the source is restarting
                    log.debug('[%s] source api error: %s' % (self.target, e))
                    time.sleep(1)
        raise RuntimeError('api call failed')

    def receiver(self):
        #
        # The source thread routine -- get events from the
        # channel and forward them into the common event queue
        #
        # The routine exists on an event with error code == 104
        #
        while True:
            with self.lock:
                if self.shutdown.is_set():
                    log.debug('[%s] stopped' % self.target)
                    self.state.set('stopped')
                    return

                if self.nl is not None:
                    try:
                        self.nl.close(code=0)
                    except Exception as e:
                        self.log.append('restarting the nl transport')
                        log.warning('[%s] source restart: %s' %
                                    (self.target, e))
                try:
                    self.log.append('connecting')
                    log.debug('[%s] connecting' % self.target)
                    self.state.set('connecting')
                    if isinstance(self.nl_prime, type):
                        self.nl = self.nl_prime(**self.nl_kwarg)
                    else:
                        raise TypeError('source channel not supported')
                    self.log.append('loading')
                    log.debug('[%s] loading' % self.target)
                    self.state.set('loading')
                    #
                    self.nl.bind(async_cache=True, clone_socket=True)
                    #
                    # Initial load -- enqueue the data
                    #
                    self.ndb.schema.allow_read(False)
                    try:
                        self.ndb.schema.flush(self.target)
                        self.evq.put((self.target, self.nl.get_links()))
                        self.evq.put((self.target, self.nl.get_addr()))
                        self.evq.put((self.target, self.nl.get_neighbours()))
                        self.evq.put((self.target, self.nl.get_routes()))
                    finally:
                        self.ndb.schema.allow_read(True)
                    self.started.set()
                    self.shutdown.clear()
                    self.log.append('running')
                    log.debug('[%s] running' % self.target)
                    self.state.set('running')
                    if self.event is not None:
                        self.evq.put((self.target, (self.event, )))
                except TypeError:
                    raise
                except Exception as e:
                    self.started.set()
                    self.log.append('failed')
                    log.debug('[%s] failed' % self.target)
                    self.state.set('failed')
                    log.error('[%s] source error: %s %s' %
                              (self.target, type(e), e))
                    self.evq.put((self.target, (MarkFailed(), )))
                    if self.persistent:
                        self.log.append('sleeping before restart')
                        log.debug('[%s] sleeping before restart' % self.target)
                        self.shutdown.wait(SOURCE_FAIL_PAUSE)
                        if self.shutdown.is_set():
                            self.log.append('source shutdown')
                            log.debug('[%s] source shutdown' % self.target)
                            return
                    else:
                        return
                    continue

            while True:
                try:
                    msg = tuple(self.nl.get())
                except Exception as e:
                    self.log.append('error: %s %s' % (type(e), e))
                    log.error('[%s] source error: %s %s' %
                              (self.target, type(e), e))
                    msg = None
                    if self.persistent:
                        break

                if msg is None or \
                        msg[0]['header']['error'] and \
                        msg[0]['header']['error'].code == 104:
                    self.log.append('stopped')
                    log.debug('[%s] stopped' % self.target)
                    self.state.set('stopped')
                    # thus we make sure that all the events from
                    # this source are consumed by the main loop
                    # in __dbm__() routine
                    sync = threading.Event()
                    self.evq.put((self.target, (sync, )))
                    sync.wait()
                    return

                self.ndb.schema._allow_write.wait()
                self.evq.put((self.target, msg))

    def start(self):

        #
        # Start source thread
        with self.lock:
            self.log.append('starting the source')
            if (self.th is not None) and self.th.is_alive():
                raise RuntimeError('source is running')

            self.th = (threading.Thread(target=self.receiver,
                                        name='NDB event source: %s' %
                                        (self.target)))
            self.th.start()
            return self

    def close(self):
        with self.lock:
            self.log.append('stopping the source')
            self.shutdown.set()
            if self.nl is not None:
                try:
                    self.nl.close()
                except Exception as e:
                    log.error('[%s] source close: %s' % (self.target, e))
        if self.th is not None:
            self.th.join()
        self.log.append('flushing the DB for the target')
        self.ndb.schema.flush(self.target)

    def restart(self):
        with self.lock:
            if not self.shutdown.is_set():
                self.log.append('restarting the source')
                self.evq.put((self.target, (SchemaReadLock(), )))
                try:
                    self.close()
                    self.start()
                finally:
                    self.evq.put((self.target, (SchemaReadUnlock(), )))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def load_sql(self):
        #
        spec = self.ndb.schema.fetchone(
            '''
                                        SELECT * FROM sources
                                        WHERE f_target = %s
                                        ''' % self.ndb.schema.plch,
            (self.target, ))
        self['target'], self['kind'] = spec
        for spec in self.ndb.schema.fetch(
                '''
                                          SELECT * FROM options
                                          WHERE f_target = %s
                                          ''' % self.ndb.schema.plch,
            (self.target, )):
            f_target, f_name, f_type, f_value = spec
            self[f_name] = int(f_value) if f_type == 'int' else f_value
Ejemplo n.º 3
0
class RTNL_Object(dict):
    '''
    Base RTNL object class.
    '''

    table = None  # model table -- always one of the main tables
    view = None  # (optional) view to load values for the summary etc.
    utable = None  # table to send updates to
    table_alias = ''

    key = None
    key_extra_fields = []

    schema = None
    event_map = None
    state = None
    log = None
    summary = None
    summary_header = None
    dump = None
    dump_header = None
    errors = None
    msg_class = None
    reverse_update = None
    _replace = None

    def __init__(self,
                 view,
                 key,
                 iclass,
                 ctxid=None,
                 match_src=None,
                 match_pairs=None):
        self.view = view
        self.ndb = view.ndb
        self.sources = view.ndb.sources
        self.ctxid = ctxid
        self.schema = view.ndb.schema
        self.match_src = match_src or tuple()
        self.match_pairs = match_pairs or dict()
        self.changed = set()
        self.iclass = iclass
        self.utable = self.utable or self.table
        self.errors = []
        self.log = self.ndb.log.channel('rtnl_object')
        self.log.debug('init')
        self.state = State()
        self.state.set('invalid')
        self.snapshot_deps = []
        self.load_event = threading.Event()
        self.lock = threading.Lock()
        self.kspec = self.schema.compiled[self.table]['idx']
        self.knorm = self.schema.compiled[self.table]['norm_idx']
        self.spec = self.schema.compiled[self.table]['all_names']
        self.names = self.schema.compiled[self.table]['norm_names']
        if isinstance(key, dict):
            self.chain = key.pop('ndb_chain', None)
            create = key.pop('create', False)
        else:
            self.chain = None
            create = False
        ckey = self.complete_key(key)
        if create and ckey is not None:
            raise KeyError('object exists')
        elif not create and ckey is None:
            raise KeyError('object does not exists')
        elif create:
            for name in key:
                self[name] = key[name]
            # FIXME -- merge with complete_key()
            if 'target' not in self:
                self.load_value('target', 'localhost')
        elif ctxid is None:
            self.key = ckey
            self.load_sql()
        else:
            self.key = ckey
            self.load_sql(table=self.table)

    @classmethod
    def nla2name(self, name):
        return self.msg_class.nla2name(name)

    @classmethod
    def name2nla(self, name):
        return self.msg_class.name2nla(name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.commit()

    def __hash__(self):
        return id(self)

    def __getitem__(self, key):
        if key in self.match_pairs:
            for src in self.match_src:
                try:
                    return src[self.match_pairs[key]]
                except:
                    pass
        return dict.__getitem__(self, key)

    def __setitem__(self, key, value):
        if self.state == 'system' and key in self.knorm:
            raise ValueError('Attempt to change a key field (%s)' % key)
        if key in ('net_ns_fd', 'net_ns_pid'):
            self.state.set('setns')
        if value != self.get(key, None):
            self.changed.add(key)
            dict.__setitem__(self, key, value)

    def fields(self, *argv):
        Fields = collections.namedtuple('Fields', argv)
        return Fields(*[self[key] for key in argv])

    def key_repr(self):
        return repr(self.key)

    @cli.change_pointer
    def create(self, **spec):
        spec['create'] = True
        spec['ndb_chain'] = self
        return self.view[spec]

    @cli.show_result
    def show(self, **kwarg):
        fmt = kwarg.pop(
            'format',
            kwarg.pop('fmt', self.view.ndb.config.get('show_format',
                                                      'native')))
        if fmt == 'native':
            return dict(self)
        else:
            out = collections.OrderedDict()
            for key in sorted(self):
                out[key] = self[key]
            return '%s\n' % json.dumps(out, indent=4, separators=(',', ': '))

    def set(self, key, value):
        self[key] = value
        return self

    def wtime(self, itn=1):
        return max(min(itn * 0.1, 1), self.view.ndb._event_queue.qsize() / 10)

    @property
    def etable(self):
        if self.ctxid:
            return '%s_%s' % (self.table, self.ctxid)
        else:
            return self.table

    @property
    def key(self):
        ret = collections.OrderedDict()
        for name in self.kspec:
            kname = self.iclass.nla2name(name)
            if kname in self:
                ret[name] = self[kname]
        if len(ret) < len(self.kspec):
            for name in self.key_extra_fields:
                kname = self.iclass.nla2name(name)
                if self.get(kname):
                    ret[name] = self[kname]
        return ret

    @key.setter
    def key(self, k):
        if not isinstance(k, dict):
            return
        for key, value in k.items():
            if value is not None:
                dict.__setitem__(self, self.iclass.nla2name(key), value)

    def snapshot(self, ctxid=None):
        ctxid = ctxid or self.ctxid or id(self)
        if self._replace is None:
            key = self.key
        else:
            key = self._replace.key
        snp = type(self)(self.view, key, ctxid=ctxid)
        self.ndb.schema.save_deps(ctxid, weakref.ref(snp), self.iclass)
        snp.changed = set(self.changed)
        return snp

    def complete_key(self, key):
        fetch = []
        for name in self.kspec:
            if name not in key:
                fetch.append('f_%s' % name)

        if fetch:
            keys = []
            values = []
            for name, value in key.items():
                nla_name = self.iclass.name2nla(name)
                if nla_name in self.spec:
                    name = nla_name
                if value is not None and name in self.spec:
                    keys.append('f_%s = %s' % (name, self.schema.plch))
                    values.append(value)
            spec = (self.ndb.schema.fetchone(
                'SELECT %s FROM %s WHERE %s' %
                (' , '.join(fetch), self.etable, ' AND '.join(keys)), values))
            if spec is None:
                return None
            for name, value in zip(fetch, spec):
                key[name[2:]] = value

        return key

    def rollback(self, snapshot=None):
        if self._replace is not None:
            self.log.debug('rollback replace: %s :: %s' %
                           (self.key, self._replace.key))
            new_replace = type(self)(self.view, self.key)
            new_replace.state.set('remove')
            self.state.set('replace')
            self.update(self._replace)
            self._replace = new_replace
        self.log.debug('rollback: %s' % str(self.state.events))
        snapshot = snapshot or self.last_save
        snapshot.state.set(self.state.get())
        snapshot.apply(rollback=True)
        for link, snp in snapshot.snapshot_deps:
            link.rollback(snapshot=snp)
        return self

    def commit(self):
        if self.chain:
            self.chain.commit()
        self.log.debug('commit: %s' % str(self.state.events))
        # Is it a new object?
        if self.state == 'invalid':
            # Save values, try to apply
            save = dict(self)
            try:
                return self.apply()
            except Exception as e_i:
                # ACHTUNG! The routine doesn't clean up the system
                #
                # Drop all the values and rollback to the initial state
                for key in tuple(self.keys()):
                    del self[key]
                for key in save:
                    dict.__setitem__(self, key, save[key])
                raise e_i

        # Continue with an existing object

        # Save the context
        self.last_save = self.snapshot()

        # The snapshot tables in the DB will be dropped as soon as the GC
        # collects the object. But in the case of an exception the `snp`
        # variable will be saved in the traceback, so the tables will be
        # available to debug. If the traceback will be saved somewhere then
        # the tables will never be dropped by the GC, so you can do it
        # manually by `ndb.schema.purge_snapshots()` -- to invalidate all
        # the snapshots and to drop the associated tables.

        # Apply the changes
        try:
            self.apply()
        except Exception as e_c:
            # Rollback in the case of any error
            try:
                self.rollback()
            except Exception as e_r:
                e_c.chain = [e_r]
                if hasattr(e_r, 'chain'):
                    e_c.chain.extend(e_r.chain)
                e_r.chain = None
            raise
        finally:
            (self.last_save.state.set(self.state.get()))
        return self

    def remove(self):
        with self.lock:
            self.state.set('remove')
            return self

    def check(self):
        state_map = (('invalid', 'system'), ('remove', 'invalid'),
                     ('setns', 'invalid'), ('replace', 'system'))

        self.load_sql()
        self.log.debug('check: %s' % str(self.state.events))

        if self.state.transition() not in state_map:
            self.log.debug('check state: False')
            return False

        if self.changed:
            self.log.debug('check changed: %s' % (self.changed))
            return False

        self.log.debug('check: True')
        return True

    def make_req(self, prime):
        req = dict(prime)
        for key in self.changed:
            req[key] = self[key]
        return req

    def get_count(self):
        conditions = []
        values = []
        for name in self.kspec:
            conditions.append('f_%s = %s' % (name, self.schema.plch))
            values.append(self.get(self.iclass.nla2name(name), None))
        return (self.ndb.schema.fetchone(
            '''
                          SELECT count(*) FROM %s WHERE %s
                          ''' % (self.table, ' AND '.join(conditions)),
            values))[0]

    def hook_apply(self, method, **spec):
        pass

    def apply(self, rollback=False, fix_remove=True):

        self.log.debug('apply: %s' % str(self.state.events))
        self.load_event.clear()

        # Load the current state
        try:
            self.schema.commit()
        except:
            pass
        self.load_sql(set_state=False)
        if self.state == 'system' and self.get_count() == 0:
            state = self.state.set('invalid')
        else:
            state = self.state.get()

        # Create the request.
        idx_req = dict([(x, self[self.iclass.nla2name(x)])
                        for x in self.schema.compiled[self.table]['idx']
                        if self.iclass.nla2name(x) in self])
        req = self.make_req(idx_req)
        self.log.debug('apply req: %s' % str(req))
        self.log.debug('apply idx_req: %s' % str(idx_req))
        self.log.debug('apply state: %s' % state)

        method = None
        ignore = tuple()
        #
        if state in ('invalid', 'replace'):
            req = dict([x for x in self.items() if x[1] is not None])
            for l_key, r_key in self.match_pairs.items():
                for src in self.match_src:
                    try:
                        req[l_key] = src[r_key]
                        break
                    except:
                        pass
            method = 'add'
            ignore = {errno.EEXIST: 'set'}
        elif state == 'system':
            method = 'set'
        elif state == 'setns':
            method = 'set'
            ignore = {errno.ENODEV: None}
        elif state == 'remove':
            method = 'del'
            req = idx_req
            ignore = {errno.ENODEV: None, errno.ESRCH: None}
        else:
            raise Exception('state transition not supported')

        for itn in range(20):
            try:
                self.log.debug('run %s (%s)' % (method, req))
                (self.sources[self['target']].api(self.api, method, **req))
                (self.hook_apply(method, **req))
            except NetlinkError as e:
                (self.log.debug('error: %s' % e))
                if e.code in ignore:
                    self.log.debug('ignore error %s for %s' % (e.code, self))
                    if ignore[e.code] is not None:
                        self.log.debug('run fallback %s (%s)' %
                                       (ignore[e.code], req))
                        try:
                            (self.sources[self['target']].api(
                                self.api, ignore[e.code], **req))
                        except NetlinkError:
                            pass
                else:
                    raise e

            wtime = self.wtime(itn)
            mqsize = self.view.ndb._event_queue.qsize()
            nqsize = self.schema.stats.get(self['target']).qsize
            self.log.debug('stats: apply %s {'
                           'objid %s, wtime %s, '
                           'mqsize %s, nqsize %s'
                           '}' % (method, id(self), wtime, mqsize, nqsize))
            if self.check():
                self.log.debug('checked')
                break
            self.log.debug('check failed')
            self.load_event.wait(wtime)
            self.load_event.clear()
        else:
            self.log.debug('stats: %s apply %s fail' % (id(self), method))
            raise Exception('lost sync in apply()')

        self.log.debug('stats: %s pass' % (id(self)))
        #
        if state == 'replace':
            self._replace.remove()
            self._replace.apply()
            self._replace = None
        #
        if rollback:
            #
            # Iterate all the snapshot tables and collect the diff
            for (table, cls) in self.view.classes.items():
                if issubclass(type(self), cls) or \
                        issubclass(cls, type(self)):
                    continue
                # comprare the tables
                diff = (self.ndb.schema.fetch('''
                               SELECT * FROM %s_%s
                                 EXCEPT
                               SELECT * FROM %s
                               ''' % (table, self.ctxid, table)))
                for record in diff:
                    record = dict(
                        zip((self.schema.compiled[table]['all_names']),
                            record))
                    key = dict([
                        x for x in record.items()
                        if x[0] in self.schema.compiled[table]['idx']
                    ])
                    obj = self.view.get(key, table)
                    obj.load_sql(ctxid=self.ctxid)
                    obj.state.set('invalid')
                    try:
                        obj.apply()
                    except Exception as e:
                        self.errors.append((time.time(), obj, e))
        return self

    def update(self, data):
        for key, value in data.items():
            self.load_value(key, value)

    def load_value(self, key, value):
        if key not in self.changed:
            dict.__setitem__(self, key, value)
        elif self.get(key) == value:
            self.changed.remove(key)

    def load_sql(self, table=None, ctxid=None, set_state=True):

        if not self.key:
            return

        if table is None:
            if ctxid is None:
                table = self.etable
            else:
                table = '%s_%s' % (self.table, ctxid)
        keys = []
        values = []

        for name, value in self.key.items():
            keys.append('f_%s = %s' % (name, self.schema.plch))
            values.append(value)

        spec = (self.ndb.schema.fetchone(
            'SELECT * FROM %s WHERE %s' % (table, ' AND '.join(keys)), values))
        self.log.debug('load_sql: %s' % str(spec))
        if set_state:
            with self.lock:
                if spec is None:
                    if self.state != 'invalid':
                        # No such object (anymore)
                        self.state.set('invalid')
                        self.changed = set()
                elif self.state not in ('remove', 'setns'):
                    self.update(dict(zip(self.names, spec)))
                    self.state.set('system')
        return spec

    def load_rtnlmsg(self, target, event):
        # TODO: partial match (object rename / restore)
        # ...

        # full match
        for name, value in self.key.items():
            if name == 'target':
                if value != target:
                    return
            elif value != (event.get_attr(name) or event.get(name)):
                return

        self.log.debug('load_rtnl: %s' % str(event.get('header')))
        if event['header'].get('type', 0) % 2:
            self.state.set('invalid')
            self.changed = set()
        else:
            self.load_sql()
        self.load_event.set()
Ejemplo n.º 4
0
class Source(dict):
    '''
    The RNTL source. The source that is used to init the object
    must comply to IPRoute API, must support the async_cache. If
    the source starts additional threads, they must be joined
    in the source.close()
    '''
    table_alias = 'src'
    dump = None
    dump_header = None
    summary = None
    summary_header = None
    view = None
    table = 'sources'
    vmap = {
        'local': IPRoute,
        'netns': NetNS,
        'remote': RemoteIPRoute,
        'nsmanager': NetNSManager
    }

    def __init__(self, ndb, **spec):
        self.th = None
        self.nl = None
        self.ndb = ndb
        self.evq = self.ndb._event_queue
        # the target id -- just in case
        self.target = spec['target']
        self.kind = spec.pop('kind', 'local')
        self.persistent = spec.pop('persistent', True)
        self.event = spec.pop('event')
        # RTNL API
        self.nl_prime = self.vmap[self.kind]
        self.nl_kwarg = spec
        #
        self.shutdown = threading.Event()
        self.started = threading.Event()
        self.lock = threading.RLock()
        self.shutdown_lock = threading.RLock()
        self.started.clear()
        self.log = ndb.log.channel('sources.%s' % self.target)
        self.state = State(log=self.log)
        self.state.set('init')
        self.ndb.schema.execute(
            '''
                                INSERT INTO sources (f_target, f_kind)
                                VALUES (%s, %s)
                                ''' %
            (self.ndb.schema.plch, self.ndb.schema.plch),
            (self.target, self.kind))
        for key, value in spec.items():
            vtype = 'int' if isinstance(value, int) else 'str'
            self.ndb.schema.execute(
                '''
                                    INSERT INTO sources_options
                                    (f_target, f_name, f_type, f_value)
                                    VALUES (%s, %s, %s, %s)
                                    ''' %
                (self.ndb.schema.plch, self.ndb.schema.plch,
                 self.ndb.schema.plch, self.ndb.schema.plch),
                (self.target, key, vtype, value))

        self.load_sql()

    @classmethod
    def defaults(cls, spec):
        ret = dict(spec)
        defaults = {}
        if 'hostname' in spec:
            defaults['kind'] = 'remote'
            defaults['protocol'] = 'ssh'
            defaults['target'] = spec['hostname']
        if 'netns' in spec:
            defaults['kind'] = 'netns'
            defaults['target'] = spec['netns']
            ret['netns'] = netns._get_netnspath(spec['netns'])
        for key in defaults:
            if key not in ret:
                ret[key] = defaults[key]
        return ret

    def __repr__(self):
        if isinstance(self.nl_prime, NetlinkMixin):
            name = self.nl_prime.__class__.__name__
        elif isinstance(self.nl_prime, type):
            name = self.nl_prime.__name__

        return '[%s] <%s %s>' % (self.state.get(), name, self.nl_kwarg)

    @classmethod
    def nla2name(cls, name):
        return name

    @classmethod
    def name2nla(cls, name):
        return name

    def api(self, name, *argv, **kwarg):
        for _ in range(100):  # FIXME make a constant
            with self.lock:
                try:
                    return getattr(self.nl, name)(*argv, **kwarg)
                except (NetlinkError, AttributeError, ValueError, KeyError,
                        TypeError, socket.error, struct.error):
                    raise
                except Exception as e:
                    # probably the source is restarting
                    self.log.debug('source api error: %s' % e)
                    time.sleep(1)
        raise RuntimeError('api call failed')

    def receiver(self):
        #
        # The source thread routine -- get events from the
        # channel and forward them into the common event queue
        #
        # The routine exists on an event with error code == 104
        #
        stop = False
        while not stop:
            with self.lock:
                if self.shutdown.is_set():
                    break

                if self.nl is not None:
                    try:
                        self.nl.close(code=0)
                    except Exception as e:
                        self.log.warning('source restart: %s' % e)
                try:
                    self.state.set('connecting')
                    if isinstance(self.nl_prime, type):
                        spec = {}
                        spec.update(self.nl_kwarg)
                        if self.kind in ('nsmanager', ):
                            spec['libc'] = self.ndb.libc
                        self.nl = self.nl_prime(**spec)
                    else:
                        raise TypeError('source channel not supported')
                    self.state.set('loading')
                    #
                    self.nl.bind(async_cache=True, clone_socket=True)
                    #
                    # Initial load -- enqueue the data
                    #
                    self.ndb.schema.allow_read(False)
                    try:
                        self.ndb.schema.flush(self.target)
                        self.evq.put(self.nl.dump())
                    finally:
                        self.ndb.schema.allow_read(True)
                    self.started.set()
                    self.shutdown.clear()
                    self.state.set('running')
                    if self.event is not None:
                        self.evq.put((cmsg_event(self.target, self.event), ))
                    else:
                        self.evq.put((cmsg_sstart(self.target), ))
                except Exception as e:
                    self.started.set()
                    self.state.set('failed')
                    self.log.error('source error: %s %s' % (type(e), e))
                    try:
                        self.evq.put((cmsg_failed(self.target), ))
                    except ShutdownException:
                        stop = True
                        break
                    if self.persistent:
                        self.log.debug('sleeping before restart')
                        self.shutdown.wait(SOURCE_FAIL_PAUSE)
                        if self.shutdown.is_set():
                            self.log.debug('source shutdown')
                            stop = True
                            break
                    else:
                        self.event.set()
                        return
                    continue

            while not stop:
                try:
                    msg = tuple(self.nl.get())
                except Exception as e:
                    self.log.error('source error: %s %s' % (type(e), e))
                    msg = None
                    if not self.persistent:
                        stop = True
                    break

                code = 0
                if msg and msg[0]['header']['error']:
                    code = msg[0]['header']['error'].code

                if msg is None or code == errno.ECONNRESET:
                    stop = True
                    break

                self.ndb.schema._allow_write.wait()
                try:
                    self.evq.put(msg)
                except ShutdownException:
                    stop = True
                    break

        # thus we make sure that all the events from
        # this source are consumed by the main loop
        # in __dbm__() routine
        try:
            self.sync()
            self.log.debug('flush DB for the target')
            self.ndb.schema.flush(self.target)
        except ShutdownException:
            self.log.debug('shutdown handled by the main thread')
            pass
        self.state.set('stopped')

    def sync(self):
        self.log.debug('sync')
        sync = threading.Event()
        self.evq.put((cmsg_event(self.target, sync), ))
        sync.wait()

    def start(self):

        #
        # Start source thread
        with self.lock:
            self.log.debug('starting the source')
            if (self.th is not None) and self.th.is_alive():
                raise RuntimeError('source is running')

            self.th = (threading.Thread(target=self.receiver,
                                        name='NDB event source: %s' %
                                        (self.target)))
            self.th.start()
            return self

    def close(self, code=errno.ECONNRESET, sync=True):
        with self.shutdown_lock:
            if self.shutdown.is_set():
                self.log.debug('already stopped')
                return
            self.log.info('source shutdown')
            self.shutdown.set()
            if self.nl is not None:
                try:
                    self.nl.close(code=code)
                except Exception as e:
                    self.log.error('source close: %s' % e)
        if sync:
            if self.th is not None:
                self.th.join()
                self.th = None
            else:
                self.log.debug('receiver thread missing')

    def restart(self, reason='unknown'):
        with self.lock:
            with self.shutdown_lock:
                self.log.debug('restarting the source, reason <%s>' % (reason))
                self.started.clear()
                self.ndb.schema.allow_read(False)
                try:
                    self.close()
                    if self.th:
                        self.th.join()
                    self.shutdown.clear()
                    self.start()
                finally:
                    self.ndb.schema.allow_read(True)
        self.started.wait()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def load_sql(self):
        #
        spec = self.ndb.schema.fetchone(
            '''
                                        SELECT * FROM sources
                                        WHERE f_target = %s
                                        ''' % self.ndb.schema.plch,
            (self.target, ))
        self['target'], self['kind'] = spec
        for spec in self.ndb.schema.fetch(
                '''
                                          SELECT * FROM sources_options
                                          WHERE f_target = %s
                                          ''' % self.ndb.schema.plch,
            (self.target, )):
            f_target, f_name, f_type, f_value = spec
            self[f_name] = int(f_value) if f_type == 'int' else f_value
Ejemplo n.º 5
0
class RTNL_Object(dict):
    '''
    Base RTNL object class.
    '''

    table = None   # model table -- always one of the main tables
    view = None    # (optional) view to load values for the summary etc.
    utable = None  # table to send updates to
    table_alias = ''

    key_extra_fields = []
    fields_cmp = {}

    schema = None
    event_map = None
    state = None
    log = None
    summary = None
    summary_header = None
    dump = None
    dump_header = None
    errors = None
    msg_class = None
    reverse_update = None
    _apply_script = None
    _key = None
    _replace = None
    _replace_on_key_change = False

    def __init__(self,
                 view,
                 key,
                 iclass,
                 ctxid=None,
                 match_src=None,
                 match_pairs=None):
        self.view = view
        self.ndb = view.ndb
        self.sources = view.ndb.sources
        self.ctxid = ctxid
        self.schema = view.ndb.schema
        self.match_src = match_src or tuple()
        self.match_pairs = match_pairs or dict()
        self.changed = set()
        self.iclass = iclass
        self.utable = self.utable or self.table
        self.errors = []
        self.log = self.ndb.log.channel('rtnl_object')
        self.log.debug('init')
        self.state = State()
        self.state.set('invalid')
        self.snapshot_deps = []
        self.load_event = threading.Event()
        self.load_event.set()
        self.lock = threading.Lock()
        self.kspec = self.schema.compiled[self.table]['idx']
        self.knorm = self.schema.compiled[self.table]['norm_idx']
        self.spec = self.schema.compiled[self.table]['all_names']
        self.names = self.schema.compiled[self.table]['norm_names']
        self._apply_script = []
        if isinstance(key, dict):
            self.chain = key.pop('ndb_chain', None)
            create = key.pop('create', False)
        else:
            self.chain = None
            create = False
        ckey = self.complete_key(key)
        if create and ckey is not None:
            raise KeyError('object exists')
        elif not create and ckey is None:
            raise KeyError('object does not exists')
        elif create:
            for name in key:
                self[name] = key[name]
            # FIXME -- merge with complete_key()
            if 'target' not in self:
                self.load_value('target', 'localhost')
        elif ctxid is None:
            self.key = ckey
            self.load_sql()
        else:
            self.key = ckey
            self.load_sql(table=self.table)

    @classmethod
    def adjust_spec(cls, spec):
        return spec

    @classmethod
    def nla2name(self, name):
        return self.msg_class.nla2name(name)

    @classmethod
    def name2nla(self, name):
        return self.msg_class.name2nla(name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.commit()

    def __hash__(self):
        return id(self)

    def __getitem__(self, key):
        if key in self.match_pairs:
            for src in self.match_src:
                try:
                    return src[self.match_pairs[key]]
                except:
                    pass
        return dict.__getitem__(self, key)

    def __setitem__(self, key, value):
        if self.state == 'system' and key in self.knorm:
            if self._replace_on_key_change:
                self.log.debug('prepare replace for key %s' % (self.key))
                self._replace = type(self)(self.view, self.key)
                self.state.set('replace')
            else:
                raise ValueError('Attempt to change a key field (%s)' % key)
        if key in ('net_ns_fd', 'net_ns_pid'):
            self.state.set('setns')
        if value != self.get(key, None):
            self.changed.add(key)
            dict.__setitem__(self, key, value)

    def fields(self, *argv):
        Fields = collections.namedtuple('Fields', argv)
        return Fields(*[self[key] for key in argv])

    def key_repr(self):
        return repr(self.key)

    @cli.change_pointer
    def create(self, **spec):
        spec['create'] = True
        spec['ndb_chain'] = self
        return self.view[spec]

    @cli.show_result
    def show(self, **kwarg):
        fmt = kwarg.pop('format',
                        kwarg.pop('fmt',
                                  self.view.ndb.config.get('show_format',
                                                           'native')))
        if fmt == 'native':
            return dict(self)
        else:
            out = collections.OrderedDict()
            for key in sorted(self):
                out[key] = self[key]
            return '%s\n' % json.dumps(out, indent=4, separators=(',', ': '))

    def set(self, key, value):
        self[key] = value
        return self

    def wtime(self, itn=1):
        return max(min(itn * 0.1, 1),
                   self.view.ndb._event_queue.qsize() / 10)

    @property
    def etable(self):
        if self.ctxid:
            return '%s_%s' % (self.table, self.ctxid)
        else:
            return self.table

    @property
    def key(self):
        nkey = self._key or {}
        ret = collections.OrderedDict()
        for name in self.kspec:
            kname = self.iclass.nla2name(name)
            if kname in self:
                value = self[kname]
                if value is None and name in nkey:
                    value = nkey[name]
                ret[name] = value
        if len(ret) < len(self.kspec):
            for name in self.key_extra_fields:
                kname = self.iclass.nla2name(name)
                if self.get(kname):
                    ret[name] = self[kname]
        return ret

    @key.setter
    def key(self, k):
        if not isinstance(k, dict):
            return
        for key, value in k.items():
            if value is not None:
                dict.__setitem__(self, self.iclass.nla2name(key), value)

    def register(self):
        #
        # Construct a weakref handler for events.
        #
        # If the referent doesn't exist, raise the
        # exception to remove the handler from the
        # chain.
        #
        def wr_handler(wr, fname, *argv):
            try:
                return getattr(wr(), fname)(*argv)
            except:
                # check if the weakref became invalid
                if wr() is None:
                    raise InvalidateHandlerException()
                raise

        wr = weakref.ref(self)
        for event, fname in self.event_map.items():
            #
            # Do not trust the implicit scope and pass the
            # weakref explicitly via partial
            #
            (self
             .ndb
             .register_handler(event,
                               partial(wr_handler, wr, fname)))

    def snapshot(self, ctxid=None):
        ctxid = ctxid or self.ctxid or id(self)
        if self._replace is None:
            key = self.key
        else:
            key = self._replace.key
        snp = type(self)(self.view, key, ctxid=ctxid)
        self.ndb.schema.save_deps(ctxid, weakref.ref(snp), self.iclass)
        snp.changed = set(self.changed)
        return snp

    def complete_key(self, key):
        self.log.debug('complete key %s from table %s' % (key, self.etable))
        fetch = []
        if isinstance(key, Record):
            key = key._as_dict()
        for name in self.kspec:
            if name not in key:
                fetch.append('f_%s' % name)

        if fetch:
            keys = []
            values = []
            for name, value in key.items():
                nla_name = self.iclass.name2nla(name)
                if nla_name in self.spec:
                    name = nla_name
                if value is not None and name in self.spec:
                    keys.append('f_%s = %s' % (name, self.schema.plch))
                    values.append(value)
            spec = (self
                    .ndb
                    .schema
                    .fetchone('SELECT %s FROM %s WHERE %s' %
                              (' , '.join(fetch),
                               self.etable,
                               ' AND '.join(keys)),
                              values))
            if spec is None:
                self.log.debug('got none')
                return None
            for name, value in zip(fetch, spec):
                key[name[2:]] = value

        self.log.debug('got %s' % key)
        return key

    def rollback(self, snapshot=None):
        if self._replace is not None:
            self.log.debug('rollback replace: %s :: %s'
                           % (self.key, self._replace.key))
            new_replace = type(self)(self.view, self.key)
            new_replace.state.set('remove')
            self.state.set('replace')
            self.update(self._replace)
            self._replace = new_replace
        self.log.debug('rollback: %s' % str(self.state.events))
        snapshot = snapshot or self.last_save
        snapshot.state.set(self.state.get())
        snapshot.apply(rollback=True)
        for link, snp in snapshot.snapshot_deps:
            link.rollback(snapshot=snp)
        return self

    def commit(self):
        if self.chain:
            self.chain.commit()
        self.log.debug('commit: %s' % str(self.state.events))
        # Is it a new object?
        if self.state == 'invalid':
            # Save values, try to apply
            save = dict(self)
            try:
                return self.apply()
            except Exception as e_i:
                # Save the debug info
                e_i.trace = traceback.format_exc()
                # ACHTUNG! The routine doesn't clean up the system
                #
                # Drop all the values and rollback to the initial state
                for key in tuple(self.keys()):
                    del self[key]
                for key in save:
                    dict.__setitem__(self, key, save[key])
                raise e_i

        # Continue with an existing object

        # The snapshot tables in the DB will be dropped as soon as the GC
        # collects the object. But in the case of an exception the `snp`
        # variable will be saved in the traceback, so the tables will be
        # available to debug. If the traceback will be saved somewhere then
        # the tables will never be dropped by the GC, so you can do it
        # manually by `ndb.schema.purge_snapshots()` -- to invalidate all
        # the snapshots and to drop the associated tables.

        # Apply the changes
        try:
            self.apply()
        except Exception as e_c:
            # Rollback in the case of any error
            try:
                self.rollback()
            except Exception as e_r:
                e_c.chain = [e_r]
                if hasattr(e_r, 'chain'):
                    e_c.chain.extend(e_r.chain)
                e_r.chain = None
            raise
        finally:
            (self
             .last_save
             .state
             .set(self.state.get()))
        if self._replace is not None:
            self._replace = None
        return self

    def remove(self):
        with self.lock:
            self.state.set('remove')
            return self

    def check(self):
        state_map = (('invalid', 'system'),
                     ('remove', 'invalid'),
                     ('setns', 'invalid'),
                     ('replace', 'system'))

        self.load_sql()
        self.log.debug('check: %s' % str(self.state.events))

        if self.state.transition() not in state_map:
            self.log.debug('check state: False')
            return False

        if self.changed:
            self.log.debug('check changed: %s' % (self.changed))
            return False

        self.log.debug('check: True')
        return True

    def make_req(self, prime):
        req = dict(prime)
        for key in self.changed:
            req[key] = self[key]
        return req

    def get_count(self):
        conditions = []
        values = []
        for name in self.kspec:
            conditions.append('f_%s = %s' % (name, self.schema.plch))
            values.append(self.get(self.iclass.nla2name(name), None))
        return (self
                .ndb
                .schema
                .fetchone('''
                          SELECT count(*) FROM %s WHERE %s
                          ''' % (self.table,
                                 ' AND '.join(conditions)),
                          values))[0]

    def hook_apply(self, method, **spec):
        pass

    def apply(self, rollback=False):

        # Save the context
        if not rollback and self.state != 'invalid':
            self.last_save = self.snapshot()

        self.log.debug('apply: %s' % str(self.state.events))
        self.load_event.clear()
        self._apply_script_snapshots = []

        # Load the current state
        try:
            self.schema.commit()
        except:
            pass
        self.load_sql(set_state=False)
        if self.state == 'system' and self.get_count() == 0:
            state = self.state.set('invalid')
        else:
            state = self.state.get()

        # Create the request.
        idx_req = dict([(x, self[self.iclass.nla2name(x)]) for x
                        in self.schema.compiled[self.table]['idx']
                        if self.iclass.nla2name(x) in self])
        req = self.make_req(idx_req)
        self.log.debug('apply req: %s' % str(req))
        self.log.debug('apply idx_req: %s' % str(idx_req))
        self.log.debug('apply state: %s' % state)

        method = None
        ignore = tuple()
        #
        if state in ('invalid', 'replace'):
            req = dict([x for x in self.items() if x[1] is not None])
            for l_key, r_key in self.match_pairs.items():
                for src in self.match_src:
                    try:
                        req[l_key] = src[r_key]
                        break
                    except:
                        pass
            method = 'add'
            ignore = {errno.EEXIST: 'set'}
        elif state == 'system':
            method = 'set'
        elif state == 'setns':
            method = 'set'
            ignore = {errno.ENODEV: None}
        elif state == 'remove':
            method = 'del'
            req = idx_req
            ignore = {errno.ENODEV: None,         # interfaces
                      errno.ESRCH: None,          # routes
                      errno.EADDRNOTAVAIL: None}  # addresses
        else:
            raise Exception('state transition not supported')

        for itn in range(20):
            try:
                self.log.debug('run %s (%s)' % (method, req))
                (self
                 .sources[self['target']]
                 .api(self.api, method, **req))
                (self
                 .hook_apply(method, **req))
            except NetlinkError as e:
                (self
                 .log
                 .debug('error: %s' % e))
                if e.code in ignore:
                    self.log.debug('ignore error %s for %s' % (e.code, self))
                    if ignore[e.code] is not None:
                        self.log.debug('run fallback %s (%s)'
                                       % (ignore[e.code], req))
                        try:
                            (self
                             .sources[self['target']]
                             .api(self.api, ignore[e.code], **req))
                        except NetlinkError:
                            pass
                else:
                    raise e

            wtime = self.wtime(itn)
            mqsize = self.view.ndb._event_queue.qsize()
            nqsize = self.schema.stats.get(self['target']).qsize
            self.log.debug('stats: apply %s {'
                           'objid %s, wtime %s, '
                           'mqsize %s, nqsize %s'
                           '}' % (method, id(self), wtime, mqsize, nqsize))
            if self.check():
                self.log.debug('checked')
                break
            self.log.debug('check failed')
            self.load_event.wait(wtime)
            self.load_event.clear()
        else:
            self.log.debug('stats: %s apply %s fail' % (id(self), method))
            raise Exception('lost sync in apply()')

        self.log.debug('stats: %s pass' % (id(self)))
        #
        if state == 'replace':
            self._replace.remove()
            self._replace.apply()
        #
        if rollback:
            #
            # Iterate all the snapshot tables and collect the diff
            for (table, cls) in self.view.classes.items():
                if issubclass(type(self), cls) or \
                        issubclass(cls, type(self)):
                    continue
                # comprare the tables
                diff = (self
                        .ndb
                        .schema
                        .fetch('''
                               SELECT * FROM %s_%s
                                 EXCEPT
                               SELECT * FROM %s
                               '''
                               % (table, self.ctxid, table)))
                for record in diff:
                    record = dict(zip((self
                                       .schema
                                       .compiled[table]['all_names']),
                                      record))
                    key = dict([x for x in record.items()
                                if x[0] in self.schema.compiled[table]['idx']])
                    obj = self.view.get(key, table)
                    obj.load_sql(ctxid=self.ctxid)
                    obj.state.set('invalid')
                    try:
                        obj.apply()
                    except Exception as e:
                        self.errors.append((time.time(), obj, e))
        else:
            for op, argv, kwarg in self._apply_script:
                ret = op(*argv, **kwarg)
                if isinstance(ret, Exception):
                    raise ret
                elif ret is not None:
                    self._apply_script_snapshots.append(ret)
            self._apply_script = []
        return self

    def update(self, data):
        for key, value in data.items():
            self.load_value(key, value)

    def load_value(self, key, value):
        if key not in self.changed:
            dict.__setitem__(self, key, value)
        elif self.get(key) == value:
            self.changed.remove(key)
        elif key in self.fields_cmp and self.fields_cmp[key](self, value):
            self.changed.remove(key)

    def load_sql(self, table=None, ctxid=None, set_state=True):

        if not self.key:
            return

        if table is None:
            if ctxid is None:
                table = self.etable
            else:
                table = '%s_%s' % (self.table, ctxid)
        keys = []
        values = []

        for name, value in self.key.items():
            keys.append('f_%s = %s' % (name, self.schema.plch))
            values.append(value)

        spec = (self
                .ndb
                .schema
                .fetchone('SELECT * FROM %s WHERE %s' %
                          (table, ' AND '.join(keys)), values))
        self.log.debug('load_sql: %s' % str(spec))
        if set_state:
            with self.lock:
                if spec is None:
                    if self.state != 'invalid':
                        # No such object (anymore)
                        self.state.set('invalid')
                        self.changed = set()
                elif self.state not in ('remove', 'setns'):
                    self.update(dict(zip(self.names, spec)))
                    self.state.set('system')
        return spec

    def load_rtnlmsg(self, target, event):
        # TODO: partial match (object rename / restore)
        # ...

        # full match
        for name in self.knorm:
            value = self.get(name)
            if name == 'target':
                if value != target:
                    return
            elif name == 'tflags':
                continue
            elif value != (event.get_attr(name) or event.get(name)):
                return

        self.log.debug('load_rtnl: %s' % str(event.get('header')))
        if event['header'].get('type', 0) % 2:
            self.state.set('invalid')
            self.changed = set()
        else:
            self.load_sql()
        self.load_event.set()