Example 1
    def initialize(self, service):
        self.service = service
        self.logger = PrefixLoggerAdapter(self.service.logger, self.name)
        self.tz = pytz.timezone(config.timezone)
Example 2
class Job(object):
    # Unique job name
    name = None
    # Set to False when job is disabled
    enabled = True
    # Model/Document class referenced by key
    model = None
    # Use model.get_by_id for dereference
    use_get_by_id = False
    # Group name. Only one job from group can be started
    # if is not None
    group_name = None
    # Context format version
    # None - do not store context
    # Set to version number otherwise
    # Bump to next number on incompatible context changes
    context_version = None
    #
    context_cache_key = "jobctx-%(name)s-%(pool)s-%(job_id)s"
    # Collection attributes
    ATTR_ID = "_id"
    ATTR_TS = "ts"
    ATTR_CLASS = "jcls"
    ATTR_STATUS = "s"
    ATTR_TIMEOUT = "timeout"
    ATTR_KEY = "key"
    ATTR_DATA = "data"
    ATTR_LAST = "last"  # timestamp of last run
    ATTR_LAST_STATUS = "ls"  # last completion status
    ATTR_LAST_DURATION = "ldur"  # duration of last successful run, in seconds
    ATTR_LAST_SUCCESS = "st"  # last success timestamp
    ATTR_RUNS = "runs"  # Number of runs
    ATTR_MAX_RUNS = "mruns"  # Maximum allowed number of runs
    ATTR_FAULTS = "f"  # Number of sequential faults
    ATTR_OFFSET = "o"  # Random offset [0 .. 1]
    ATTR_SAMPLE = "sample"  # Span sample

    # Job states
    S_WAIT = "W"  # Waiting to run
    S_RUN = "R"  # Running
    S_STOP = "S"  # Stopped by operator
    S_DISABLED = "D"  # Disabled by system
    S_SUSPEND = "s"  # Suspended by system

    # Exit statuses
    E_SUCCESS = "S"  # Completed successfully
    E_FAILED = "F"  # Failed
    E_EXCEPTION = "X"  # Terminated by exception
    E_DEFERRED = "D"  # Cannot be run
    E_DEREFERENCE = "d"  # Cannot be dereferenced
    E_RETRY = "r"  # Forcefully retried

    STATUS_MAP = {
        E_SUCCESS: "SUCCESS",
        E_FAILED: "FAILED",
        E_EXCEPTION: "EXCEPTION",
        E_DEFERRED: "DEFERRED",
        E_DEREFERENCE: "DEREFERENCE",
        E_RETRY: "RETRY"
    }

    class JobFailed(Exception):
        pass

    # List of exceptions to be considered failed jobs
    failed_exceptions = (JobFailed, )

    def __init__(self, scheduler, attrs):
        """
        :param scheduler: Scheduler instance
        :param attrs: dict containing record from scheduler's collection
        """
        self.scheduler = scheduler
        self.attrs = attrs
        self.object = None
        self.start_time = None
        self.duration = None
        self.logger = PrefixLoggerAdapter(scheduler.logger,
                                          self.get_display_key())
        self.context = {}

    def load_context(self, data):
        self.context = data or {}
        self.init_context()

    def init_context(self):
        """
        Perform context initialization
        """
        pass

    @tornado.gen.coroutine
    def run(self):
        with Span(server=self.scheduler.name,
                  service=self.attrs[self.ATTR_CLASS],
                  sample=self.attrs.get(self.ATTR_SAMPLE, 0),
                  in_label=self.attrs.get(self.ATTR_KEY, "")):
            self.start_time = time.time()
            if self.is_retries_exceeded():
                self.logger.info("[%s|%s] Retries exceeded. Remove job",
                                 self.name, self.attrs[Job.ATTR_ID])
                self.remove_job()
                return
            self.logger.info(
                "[%s] Starting at %s (Lag %.2fms)", self.name,
                self.scheduler.scheduler_id,
                total_seconds(datetime.datetime.now() -
                              self.attrs[self.ATTR_TS]) * 1000.0)
            # Run handler
            status = self.E_EXCEPTION
            delay = None
            with Span(service="job.dereference"):
                try:
                    ds = self.dereference()
                    can_run = self.can_run()
                except Exception as e:
                    self.logger.error("Unknown error during dereference: %s",
                                      e)
                    ds = None
                    can_run = False

            if ds:
                with Span(service="job.run"):
                    if can_run:
                        try:
                            data = self.attrs.get(self.ATTR_DATA) or {}
                            result = self.handler(**data)
                            if tornado.gen.is_future(result):
                                # Wait for future
                                result = yield result
                            status = self.E_SUCCESS
                        except RetryAfter as e:
                            self.logger.info("Retry after %ss: %s", e.delay, e)
                            status = self.E_RETRY
                            delay = e.delay
                        except self.failed_exceptions:
                            status = self.E_FAILED
                        except Exception:
                            error_report()
                            status = self.E_EXCEPTION
                    else:
                        self.logger.info("Deferred")
                        status = self.E_DEFERRED
            elif ds is not None:
                self.logger.info("Cannot dereference")
                status = self.E_DEREFERENCE
            self.duration = time.time() - self.start_time
            self.logger.info("Completed. Status: %s (%.2fms)",
                             self.STATUS_MAP.get(status, status),
                             self.duration * 1000)
            # Schedule next run
            if delay is None:
                with Span(service="job.schedule_next"):
                    self.schedule_next(status)
            else:
                with Span(service="job.schedule_retry"):
                    # Retry
                    if self.context_version:
                        ctx = self.context or None
                        ctx_key = self.get_context_cache_key()
                    else:
                        ctx = None
                        ctx_key = None
                    self.scheduler.set_next_run(
                        self.attrs[self.ATTR_ID],
                        status=status,
                        ts=datetime.datetime.now() +
                        datetime.timedelta(seconds=delay),
                        duration=self.duration,
                        context_version=self.context_version,
                        context=ctx,
                        context_key=ctx_key)

    def handler(self, **kwargs):
        """
        Job handler, must be overridden in subclasses
        """
        raise NotImplementedError()

    def get_defererence_query(self):
        """
        Get dereference query condition.
        Called by dereference()
        :return: dict or None
        """
        return {"pk": self.attrs[self.ATTR_KEY]}

    def dereference(self):
        """
        Retrieve referenced object from database
        """
        if self.model and self.use_get_by_id:
            self.object = self.model.get_by_id(self.attrs[self.ATTR_KEY])
            if not self.object:
                return False
        elif self.model:
            q = self.get_defererence_query()
            if q is None:
                return False
            try:
                # Resolve object
                self.object = self.model.objects.get(**q)
            except self.model.DoesNotExist:
                return False
        # Adjust logging
        self.logger.set_prefix(
            "%s][%s][%s" %
            (self.scheduler.name, self.name, self.get_display_key()))
        return True

    def get_display_key(self):
        """
        Return dereferenced key name
        """
        if self.object:
            return unicode(self.object)
        else:
            return self.attrs[self.ATTR_KEY]

    def can_run(self):
        """
        Check whether the job can be launched
        :return:
        """
        return True

    def get_group(self):
        return self.group_name

    def remove_job(self):
        """
        Remove job from schedule
        """
        self.scheduler.remove_job_by_id(self.attrs[self.ATTR_ID])

    def schedule_next(self, status):
        """
        Schedule next run depending on status.
        Drop job by default
        """
        self.remove_job()

    @classmethod
    def submit(cls,
               scheduler,
               name=None,
               key=None,
               data=None,
               pool=None,
               ts=None,
               delta=None,
               keep_ts=False):
        """
        Submit new job or change schedule for existing one
        :param scheduler: scheduler name
        :param name: Job full name
        :param key: Job key
        :param data: Job data
        :param pool: Pool name
        :param ts: Next run timestamp
        :param delta: Run after *delta* seconds
        :param keep_ts: Do not touch timestamp of existing jobs,
            set timestamp only for created jobs
        """
        from .scheduler import Scheduler
        scheduler = Scheduler(name=scheduler, pool=pool)
        scheduler.submit(name,
                         key=key,
                         data=data,
                         ts=ts,
                         delta=delta,
                         keep_ts=keep_ts)

    @classmethod
    def remove(cls, scheduler, name=None, key=None, pool=None):
        from .scheduler import Scheduler
        scheduler = Scheduler(name=scheduler, pool=pool)
        scheduler.remove_job(name, key=key)

    @classmethod
    def get_job_data(cls, scheduler, jcls, key=None, pool=None):
        from .scheduler import Scheduler
        scheduler = Scheduler(name=scheduler, pool=pool)
        return scheduler.get_collection().find_one({
            Job.ATTR_CLASS: jcls,
            Job.ATTR_KEY: key
        })

    def get_context_cache_key(self):
        ctx = {
            "name": self.scheduler.name,
            "pool": self.scheduler.pool or "global",
            "job_id": self.attrs[self.ATTR_ID]
        }
        return self.context_cache_key % ctx

    @classmethod
    def retry_after(cls, delay, msg=""):
        """
        Must be called from job handler to deal with temporary problems.
        Current job handler will be terminated and job
        will be scheduled after *delay* seconds
        :param delay: Delay in seconds
        :param msg: Informational message
        :return:
        """
        raise RetryAfter(msg, delay=delay)

    def is_retries_exceeded(self):
        """
        Check if maximal amount of retries is exceeded
        :return:
        """
        runs = self.attrs.get(Job.ATTR_RUNS, 0)
        max_runs = self.attrs.get(Job.ATTR_MAX_RUNS, 0)
        return max_runs and runs >= max_runs

    @staticmethod
    def get_next_timestamp(interval, offset=0.0, ts=None):
        """
        Calculate next timestamp
        :param interval:
        :param offset:
        :param ts: current timestamp
        :return: datetime object
        """
        if not ts:
            ts = time.time()
        if ts and isinstance(ts, datetime.datetime):
            ts = time.mktime(
                ts.timetuple()) + float(ts.microsecond) / 1000000.0
        # Get start of current interval
        si = ts // interval * interval
        # Shift to offset
        si += offset * interval
        # Shift to interval if in the past
        if si <= ts:
            si += interval
        return datetime.datetime.fromtimestamp(si)
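
A handler must be provided by subclasses; the sketch below is a hypothetical minimal subclass (the CleanupJob name, the "ready" flag, and the submitted key are illustrative and not part of the base class):

class CleanupJob(Job):
    """Hypothetical job: cleans up stale records for the referenced object."""
    name = "cleanup"

    def handler(self, **kwargs):
        # On a temporary problem, ask the scheduler to re-run the job later
        if not kwargs.get("ready", True):
            self.retry_after(60, msg="Backend not ready")
        # ... perform the actual cleanup here ...

# Submitting or rescheduling the job (scheduler and pool names are illustrative):
# CleanupJob.submit("scheduler", name="cleanup", key=42,
#                   data={"ready": True}, delta=300, pool="default")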
Example 3
class HTTP(object):
    CONNECT_TIMEOUT = config.http_client.connect_timeout
    REQUEST_TIMEOUT = config.http_client.request_timeout

    class HTTPError(NOCError):
        default_code = ERR_HTTP_UNKNOWN

    def __init__(self, script):
        self.script = script
        self.logger = PrefixLoggerAdapter(script.logger, "http")

    def get_url(self, path):
        address = self.script.credentials["address"]
        port = self.script.credentials.get("http_port")
        if port:
            address += ":%s" % port
        proto = self.script.credentials.get("http_protocol", "http")
        return "%s://%s%s" % (proto, address, path)

    def get(self, path, headers=None, cached=False, json=False, eof_mark=None):
        """
        Perform HTTP GET request
        :param path: URI
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Wait for eof_mark in the stream to finish the session (some devices return zero content length)
        """
        self.logger.debug("GET %s", path)
        if cached:
            cache_key = "get_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        code, headers, result = fetch_sync(self.get_url(path),
                                           headers=headers,
                                           request_timeout=60,
                                           follow_redirects=True,
                                           allow_proxy=False,
                                           validate_cert=False,
                                           eof_mark=eof_mark)
        # pylint: disable=superfluous-parens
        if not (200 <= code <= 299):  # noqa
            raise self.HTTPError(msg="HTTP Error (%s)" % result[:256],
                                 code=code)
        if json:
            try:
                result = ujson.loads(result)
            except ValueError as e:
                raise self.HTTPError(msg="Failed to decode JSON: %s" % e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def post(self,
             path,
             data,
             headers=None,
             cached=False,
             json=False,
             eof_mark=None):
        """
        Perform HTTP POST request
        :param path: URI
        :param data: Request body
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Wait for eof_mark in the stream to finish the session (some devices return zero content length)
        """
        self.logger.debug("POST %s %s", path, data)
        if cached:
            cache_key = "post_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        code, headers, result = fetch_sync(self.get_url(path),
                                           method="POST",
                                           # Assumption: fetch_sync takes the
                                           # request body via "body"
                                           body=data,
                                           headers=headers,
                                           request_timeout=60,
                                           follow_redirects=True,
                                           allow_proxy=False,
                                           validate_cert=False,
                                           eof_mark=eof_mark)
        # pylint: disable=superfluous-parens
        if not (200 <= code <= 299):  # noqa
            raise self.HTTPError(msg="HTTP Error (%s)" % result[:256],
                                 code=code)
        if json:
            try:
                return ujson.loads(result)
            except ValueError as e:
                raise self.HTTPError(msg="Failed to decode JSON: %s" % e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def close(self):
        pass
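
Scripts receive an instance of this class as self.http (see the BaseScript constructor in Example 7). A hedged usage sketch; the paths and payload are illustrative:

def fetch_device_status(script):
    # GET with JSON decoding and per-script caching
    status = script.http.get("/api/v1/status", json=True, cached=True)
    # POST a raw body; HTTPError is raised on non-2xx responses
    script.http.post("/api/v1/reboot", data="{}", json=True)
    return status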
Example 4
class BaseLoader(object):
    """
    Import directory structure:
    var/
        import/
            <system name>/
                <loader name>/
                    import.csv[.gz]  -- state to load, can have .gz extension
                    mappings.csv -- ID mappings
                    archive/
                        import-YYYY-MM-DD-HH-MM-SS.csv.gz -- imported state

    Import file format: CSV, unix line endings, UTF-8, comma-separated.
    The first column is the record id in terms of the connected system;
    the remaining columns must be defined in the *fields* variable
    (an illustrative fragment is sketched in the comment below the
    *fields* attribute).

    The file must be sorted by the first field, either as strings or as
    numbers, and the sort order must not change between imports.

    mappings.csv -- CSV, unix line endings, UTF-8, comma-separated
    mappings of IDs between NOC and the remote system. Populated by the
    loader automatically.

    :param fields: List of either field names or tuple of
        (field name, related loader name)
    """

    # Loader name
    name = None
    # Loader model
    model = None
    # Mapped fields
    mapped_fields = {}

    fields = []
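    # Illustrative only: with fields = ["id", "name", "description"],
    # a hypothetical import.csv could contain
    #
    #   10,DC-1,Primary site
    #   11,DC-2,Backup site
    #
    # and mappings.csv after loading maps remote ids to NOC ids, e.g.
    #
    #   10,5b2d0f...
    #   11,5b2d1a...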

    # List of tags to add to the created records
    tags = []

    PREFIX = config.path.etl_import
    rx_archive = re.compile(r"^import-\d{4}(?:-\d{2}){5}.csv.gz$")

    # Discard records which cannot be dereferenced
    discard_deferred = False
    # Ignore auto-generated unique fields
    ignore_unique = {"bi_id"}

    REPORT_INTERVAL = 1000

    class Deferred(Exception):
        pass

    def __init__(self, chain):
        self.chain = chain
        self.system = chain.system
        self.logger = PrefixLoggerAdapter(
            logger, "%s][%s" % (self.system.name, self.name))
        self.disable_mappings = False
        self.import_dir = os.path.join(self.PREFIX, self.system.name,
                                       self.name)
        self.archive_dir = os.path.join(self.import_dir, "archive")
        self.mappings_path = os.path.join(self.import_dir, "mappings.csv")
        self.mappings = {}
        self.new_state_path = None
        self.c_add = 0
        self.c_change = 0
        self.c_delete = 0
        # Build clean map
        self.clean_map = dict(
            (n, self.clean_str)
            for n in self.fields)  # field name -> clean function
        self.pending_deletes = []  # (id, string)
        self.reffered_errors = []  # (id, string)
        if self.is_document:
            import mongoengine.errors

            unique_fields = [
                f.name for f in six.itervalues(self.model._fields)
                if f.unique and f.name not in self.ignore_unique
            ]
            self.integrity_exception = mongoengine.errors.NotUniqueError
        else:
            # Third-party modules
            import django.db.utils

            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
                and f.name not in self.ignore_unique
            ]
            self.integrity_exception = django.db.utils.IntegrityError
        if unique_fields:
            self.unique_field = unique_fields[0]
        else:
            self.unique_field = None

    @property
    def is_document(self):
        """
        Returns True if the model is a MongoEngine Document, False if it is a Django model
        """
        return hasattr(self.model, "_fields")

    def load_mappings(self):
        """
        Load mappings file
        """
        if self.model:
            if self.is_document:
                self.update_document_clean_map()
            else:
                self.update_model_clean_map()
        if not os.path.exists(self.mappings_path):
            return
        self.logger.info("Loading mappings from %s", self.mappings_path)
        with open(self.mappings_path) as f:
            reader = csv.reader(f)
            for k, v in reader:
                self.mappings[self.clean_str(k)] = v
        self.logger.info("%d mappings restored", len(self.mappings))

    def get_new_state(self):
        """
        Returns file object of new state, or None when not present
        """
        # Try import.csv
        path = os.path.join(self.import_dir, "import.csv")
        if os.path.isfile(path):
            logger.info("Loading from %s", path)
            self.new_state_path = path
            return open(path, "r")
        # Try import.csv.gz
        path += ".gz"
        if os.path.isfile(path):
            logger.info("Loading from %s", path)
            self.new_state_path = path
            return gzip.GzipFile(path, "r")
        # No data to import
        return None

    def get_current_state(self):
        """
        Returns file object of current state or None
        """
        self.load_mappings()
        if not os.path.isdir(self.archive_dir):
            self.logger.info("Creating archive directory: %s",
                             self.archive_dir)
            try:
                os.mkdir(self.archive_dir)
            except OSError as e:
                self.logger.error("Failed to create directory: %s (%s)",
                                  self.archive_dir, e)
                # @todo: Die
        if os.path.isdir(self.archive_dir):
            fn = sorted(f for f in os.listdir(self.archive_dir)
                        if self.rx_archive.match(f))
        else:
            fn = []
        if fn:
            path = os.path.join(self.archive_dir, fn[-1])
            logger.info("Current state from %s", path)
            return gzip.GzipFile(path, "r")
        # No current state
        return six.StringIO("")

    def diff(self, old, new):
        """
        Compare old and new CSV files and yield pair of matches
        * old, new -- when changed
        * old, None -- when removed
        * None, new -- when added
        """
        def getnext(g):
            try:
                return next(g)
            except StopIteration:
                return None

        o = getnext(old)
        n = getnext(new)
        while o or n:
            if not o:
                # New
                yield None, n
                n = getnext(new)
            elif not n:
                # Removed
                yield o, None
                o = getnext(old)
            else:
                if n[0] == o[0]:
                    # Changed
                    if n != o:
                        yield o, n
                    n = getnext(new)
                    o = getnext(old)
                elif n[0] < o[0]:
                    # Added
                    yield None, n
                    n = getnext(new)
                else:
                    # Removed
                    yield o, None
                    o = getnext(old)

    def load(self):
        """
        Import new data
        """
        self.logger.info("Importing")
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            self.load_mappings()
            return
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        deferred_add = []
        deferred_change = []
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                try:
                    self.on_add(n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_add += [n]
            elif o and n is None:
                self.on_delete(o)
            else:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_change += [(o, n)]
            rn = self.c_add + self.c_change + self.c_delete
            if rn > 0 and rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Add deferred records
        while len(deferred_add):
            nd = []
            for row in deferred_add:
                try:
                    self.on_add(row)
                except self.Deferred:
                    nd += [row]
            if len(nd) == len(deferred_add):
                raise Exception("Unable to defer references")
            deferred_add = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Change deferred records
        while len(deferred_change):
            nd = []
            for o, n in deferred_change:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    nd += [(o, n)]
            if len(nd) == len(deferred_change):
                raise Exception("Unable to defer references")
            deferred_change = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)

    def find_object(self, v):
        """
        Find object by remote system/remote id
        :param v:
        :return:
        """
        if not v.get("remote_system") or not v.get("remote_id"):
            self.logger.warning("RS or RID not found")
            return None
        find_query = {
            "remote_system": v.get("remote_system"),
            "remote_id": v.get("remote_id")
        }
        try:
            return self.model.objects.get(**find_query)
        except self.model.MultipleObjectsReturned:
            if self.unique_field:
                find_query[self.unique_field] = v.get(self.unique_field)
                r = self.model.objects.filter(**find_query)
                if not r:
                    r = self.model.objects.filter(**find_query)
                return list(r)[-1]
            raise self.model.MultipleObjectsReturned
        except self.model.DoesNotExist:
            return None

    def create_object(self, v):
        """
        Create object with attributes. Override to save complex
        data structures
        """
        self.logger.debug("Create object")
        for k, nv in six.iteritems(v):
            if k == "tags":
                # Merge tags
                nv = sorted("%s:%s" % (self.system.name, x) for x in nv)
                v[k] = nv
        o = self.model(**v)
        try:
            o.save()
        except self.integrity_exception as e:
            self.logger.warning("Integrity error: %s", e)
            assert self.unique_field
            if not self.is_document:
                from django.db import connection

                connection._rollback()
            # Fallback to change object
            o = self.model.objects.get(
                **{self.unique_field: v[self.unique_field]})
            for k, nv in six.iteritems(v):
                setattr(o, k, nv)
            o.save()
        return o

    def change_object(self, object_id, v):
        """
        Change object with attributes
        """
        self.logger.debug("Changed object")
        # See: https://code.getnoc.com/noc/noc/merge_requests/49
        try:
            o = self.model.objects.get(pk=object_id)
        except self.model.DoesNotExist:
            self.logger.error("Cannot change %s:%s: Does not exist",
                              self.name, object_id)
            return None
        for k, nv in six.iteritems(v):
            if k == "tags":
                # Merge tags
                ov = o.tags or []
                nv = sorted([
                    x for x in ov if not (x.startswith(self.system.name + ":")
                                          or x == "remote:deleted")
                ] + ["%s:%s" % (self.system.name, x) for x in nv])
            setattr(o, k, nv)
        o.save()
        return o

    def on_add(self, row):
        """
        Create new record
        """
        self.logger.debug("Add: %s", ";".join(row))
        v = self.clean(row)
        # @todo: Check whether the record already exists
        if self.fields[0] in v:
            del v[self.fields[0]]
        if hasattr(self.model, "remote_system"):
            o = self.find_object(v)
        else:
            o = None
        if o:
            self.c_change += 1
            # Lost&found object with same remote_id
            self.logger.debug("Lost and Found object")
            vv = {
                "remote_system": v["remote_system"],
                "remote_id": v["remote_id"]
            }
            for fn, nv in six.iteritems(v):
                if fn in vv:
                    continue
                if getattr(o, fn) != nv:
                    vv[fn] = nv
            self.change_object(o.id, vv)
            # Restore mappings
            self.set_mappings(row[0], o.id)
        else:
            self.c_add += 1
            o = self.create_object(v)
            self.set_mappings(row[0], o.id)

    def on_change(self, o, n):
        """
        Change an existing record
        """
        self.logger.debug("Change: %s", ";".join(n))
        self.c_change += 1
        v = self.clean(n)
        vv = {"remote_system": v["remote_system"], "remote_id": v["remote_id"]}
        for fn, (ov, nv) in zip(self.fields[1:], zip_longest(o[1:], n[1:])):
            if ov != nv:
                self.logger.debug("   %s: %s -> %s", fn, ov, nv)
                vv[fn] = v[fn]
        if n[0] in self.mappings:
            self.change_object(self.mappings[n[0]], vv)
        else:
            self.logger.error("Cannot map id '%s'. Skipping.", n[0])

    def on_delete(self, row):
        """
        Delete record
        """
        self.pending_deletes += [(row[0], ";".join(row))]

    def purge(self):
        """
        Perform pending deletes
        """
        for r_id, msg in reversed(self.pending_deletes):
            self.logger.debug("Delete: %s", msg)
            self.c_delete += 1
            try:
                obj = self.model.objects.get(pk=self.mappings[r_id])
                obj.delete()
            except ValueError as e:  # Referred-to record error
                self.logger.error("%s", str(e))
                self.reffered_errors += [(r_id, msg)]
            except self.model.DoesNotExist:
                pass  # Already deleted
        self.pending_deletes = []

    def save_state(self):
        """
        Save current state
        """
        if not self.new_state_path:
            return
        self.logger.info("Summary: %d new, %d changed, %d removed", self.c_add,
                         self.c_change, self.c_delete)
        self.logger.info("Deletes blocked by remaining references: %s",
                         "\n".join(msg for _, msg in self.reffered_errors))
        t = time.localtime()
        archive_path = os.path.join(
            self.archive_dir,
            "import-%04d-%02d-%02d-%02d-%02d-%02d.csv.gz" % tuple(t[:6]))
        self.logger.info("Moving %s to %s", self.new_state_path, archive_path)
        if self.new_state_path.endswith(".gz"):
            # Simply move the file
            shutil.move(self.new_state_path, archive_path)
        else:
            # Compress the file
            self.logger.info("Compressing")
            with open(self.new_state_path, "r") as s:
                with gzip.open(archive_path, "w") as d:
                    d.write(s.read())
            os.unlink(self.new_state_path)
        self.logger.info("Saving mappings to %s", self.mappings_path)
        mdata = "\n".join("%s,%s" % (k, self.mappings[k])
                          for k in sorted(self.mappings))
        safe_rewrite(self.mappings_path, mdata)

    def clean(self, row):
        """
        Cleanup row and return a dict of field name -> value
        """
        r = dict((k, self.clean_map[k](v)) for k, v in zip(self.fields, row))
        # Fill integration fields
        r["remote_system"] = self.system.remote_system
        r["remote_id"] = self.clean_str(row[0])
        return r

    def clean_str(self, value):
        if value:
            if isinstance(value, str):
                return smart_text(value)
            elif not isinstance(value, six.string_types):
                return str(value)
            else:
                return value
        else:
            return None

    def clean_map_str(self, mappings, value):
        value = self.clean_str(value)
        if self.disable_mappings and not mappings:
            return value
        elif value:
            try:
                value = mappings[value]
            except KeyError:
                self.logger.warning("Deferred. Unknown map value: %s", value)
                raise self.Deferred
        return value

    def clean_bool(self, value):
        if value == "":
            return None
        try:
            return int(value) != 0
        except ValueError:
            pass
        value = value.lower()
        return value in ("t", "true", "y", "yes")

    def clean_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = mappings[value]
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def clean_int_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = int(mappings[value])
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def set_mappings(self, rv, lv):
        self.logger.debug("Set mapping remote: %s, local: %s", rv, lv)
        self.mappings[str(rv)] = str(lv)

    def update_document_clean_map(self):
        from mongoengine.fields import BooleanField, ReferenceField
        from noc.core.mongo.fields import PlainReferenceField, ForeignKeyField

        for fn, ft in six.iteritems(self.model._fields):
            if fn not in self.clean_map:
                continue
            if isinstance(ft, BooleanField):
                self.clean_map[fn] = self.clean_bool
            elif isinstance(ft, (PlainReferenceField, ReferenceField)):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif isinstance(ft, ForeignKeyField):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_int_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif fn in self.mapped_fields:
                self.clean_map[fn] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[fn]))

    def update_model_clean_map(self):
        from django.db.models import BooleanField, ForeignKey
        from noc.core.model.fields import DocumentReferenceField

        for f in self.model._meta.fields:
            if f.name not in self.clean_map:
                continue
            if isinstance(f, BooleanField):
                self.clean_map[f.name] = self.clean_bool
            elif isinstance(f, DocumentReferenceField):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.document,
                    )
            elif isinstance(f, ForeignKey):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.remote_field.model,
                    )
            elif f.name in self.mapped_fields:
                self.clean_map[f.name] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[f.name]))

    def check(self, chain):
        self.logger.info("Checking")
        # Get constraints
        if self.is_document:
            # Document
            required_fields = [
                f.name for f in six.itervalues(self.model._fields)
                if f.required or f.unique
            ]
            unique_fields = [
                f.name for f in six.itervalues(self.model._fields) if f.unique
            ]
        else:
            # Model
            required_fields = [
                f.name for f in self.model._meta.fields if not f.blank
            ]
            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
            ]
        if not required_fields and not unique_fields:
            self.logger.info("Nothing to check, skipping")
            return 0
        # Prepare data
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            return 0
        new_state = csv.reader(ns)
        r_index = set(
            self.fields.index(f) for f in required_fields if f in self.fields)
        u_index = set(
            self.fields.index(f) for f in unique_fields
            if f not in self.ignore_unique)
        m_index = set(self.fields.index(f) for f in self.mapped_fields)
        uv = set()
        m_data = {}  # field_number -> set of mapped ids
        # Load mapped ids
        for f in self.mapped_fields:
            line = chain.get_loader(self.mapped_fields[f])
            ls = line.get_new_state()
            if not ls:
                ls = line.get_current_state()
            ms = csv.reader(ls)
            m_data[self.fields.index(f)] = set(row[0] for row in ms)
        # Process data
        n_errors = 0
        for row in new_state:
            lr = len(row)
            # Check required fields
            for i in r_index:
                if not row[i]:
                    self.logger.error(
                        "ERROR: Required field #%d(%s) is missing in row: %s",
                        i,
                        self.fields[i],
                        ",".join(row),
                    )
                    n_errors += 1
                    continue
            # Check unique fields
            for i in u_index:
                v = row[i]
                if (i, v) in uv:
                    self.logger.error(
                        "ERROR: Field #%d(%s) value is not unique: %s",
                        i,
                        self.fields[i],
                        ",".join(row),
                    )
                    n_errors += 1
                else:
                    uv.add((i, v))
            # Check mapped fields
            for i in m_index:
                if i >= lr:
                    continue
                v = row[i]
                if v and v not in m_data[i]:
                    self.logger.error(
                        "ERROR: Field #%d(%s) == '%s' refers to non-existent record: %s",
                        i,
                        self.fields[i],
                        row[i],
                        ",".join(row),
                    )
                    n_errors += 1
        if n_errors:
            self.logger.info("%d errors found", n_errors)
        else:
            self.logger.info("No errors found")
        return n_errors

    def check_diff(self):
        def dump(cmd, row):
            print("%s %s" % (cmd, ",".join(row)))

        print("--- %s.%s" % (self.chain.system.name, self.name))
        ns = self.get_new_state()
        if not ns:
            return
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                dump("+", n)
            elif o and n is None:
                dump("-", o)
            else:
                dump("/", o)
                dump("\\", n)

    def check_diff_summary(self):
        i, u, d = 0, 0, 0
        ns = self.get_new_state()
        if not ns:
            return i, u, d
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                i += 1
            elif o and n is None:
                d += 1
            else:
                u += 1
        return i, u, d
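
A concrete loader typically just declares name, model, fields, and mapped_fields. The sketch below is hypothetical; the model and the referenced loader name are assumptions rather than part of the listing above:

class ObjectLoader(BaseLoader):
    """Hypothetical loader: one CSV column is resolved through another loader."""
    name = "object"
    model = ManagedObject  # assumed import: noc.sa.models.managedobject.ManagedObject
    fields = ["id", "name", "administrative_domain", "description"]
    # Column resolved through the mappings.csv of a hypothetical "admdiv" loader
    mapped_fields = {"administrative_domain": "admdiv"}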
Example 5
File: base.py Project: ewwwcha/noc
    def __init__(self):
        self.logger = PrefixLoggerAdapter(logger, self.name)
        self.classes = {}
        self.lock = threading.Lock()
        self.all_classes = set()
Example 6
class MODiscoveryJob(PeriodicJob):
    model = ManagedObject
    use_get_by_id = True
    use_offset = True
    # Name of umbrella class to cover discovery problems
    umbrella_cls = None
    # Job families
    is_box = False
    is_periodic = False

    def __init__(self, *args, **kwargs):
        super(MODiscoveryJob, self).__init__(*args, **kwargs)
        self.out_buffer = StringIO()
        self.logger = PrefixLoggerAdapter(self.logger,
                                          "",
                                          target=self.out_buffer)
        self.check_timings = []
        self.problems = []
        self.caps = None
        self.has_fatal_error = False
        self.service = self.scheduler.service
        # Additional artefacts can be passed between checks in one session
        self.artefacts = {}

    def schedule_next(self, status):
        if self.check_timings:
            self.logger.info(
                "Timings: %s", ", ".join("%s = %.2fms" % (n, t * 1000)
                                         for n, t in self.check_timings))
        super(MODiscoveryJob, self).schedule_next(status)
        # Update alarm statuses
        self.update_alarms()
        # Write job log
        key = "discovery-%s-%s" % (self.attrs[self.ATTR_CLASS],
                                   self.attrs[self.ATTR_KEY])
        problems = {}
        for p in self.problems:
            if p["check"] in problems and p["path"]:
                problems[p["check"]][p["path"]] = p["message"]
            elif p["check"] in problems and not p["path"]:
                # p["path"] == ""
                problems[p["check"]][p["path"]] += "; %s" % p["message"]
            else:
                problems[p["check"]] = {p["path"]: p["message"]}
        get_db()["noc.joblog"].update({"_id": key}, {
            "$set": {
                "log": bson.Binary(zlib.compress(self.out_buffer.getvalue())),
                "problems": problems
            }
        },
                                      upsert=True)

    def can_run(self):
        # @todo: Make configurable
        os = self.object.get_status()
        if not os:
            self.logger.info("Object ping failed, job will not run")
        return self.object.is_managed and os

    @contextlib.contextmanager
    def check_timer(self, name):
        t = perf_counter()
        yield
        self.check_timings += [(name, perf_counter() - t)]

    def set_problem(self,
                    check=None,
                    alarm_class=None,
                    path=None,
                    message=None,
                    fatal=False):
        """
        Set discovery problem
        :param check: Check name
        :param alarm_class: Alarm class instance or name
        :param path: Additional path
        :param message: Text message
        :param fatal: True if problem is fatal and all following checks
            must be disabled
        :return:
        """
        self.problems += [{
            "check": check,
            "alarm_class": alarm_class,
            # in MongoDB Key must be string
            "path": str(path) if path else "",
            "message": message,
            "fatal": fatal
        }]
        if fatal:
            self.has_fatal_error = True

    def get_caps(self):
        """
        Return object's capabilities
        :return:
        """
        if self.caps is None:
            self.caps = self.object.get_caps()
        return self.caps

    def update_caps(self, caps, source):
        self.caps = self.object.update_caps(caps, source=source)

    def allow_sessions(self):
        r = self.object.can_cli_session()
        if r:
            self.object.get_profile().allow_cli_session(None, None)
        return r

    def update_umbrella(self, umbrella_cls, details):
        """
        Update umbrella alarm status for managed object

        :param umbrella_cls: Umbrella alarm class (AlarmClass instance)
        :param details: List of dicts, containing
            * alarm_class - Detail alarm class
            * path - Additional path
            * severity - Alarm severity
            * vars - dict of alarm vars
        :return:
        """
        from noc.fm.models.activealarm import ActiveAlarm

        now = datetime.datetime.now()
        umbrella = ActiveAlarm.objects.filter(
            alarm_class=umbrella_cls.id,
            managed_object=self.object.id).first()
        u_sev = sum(d.get("severity", 0) for d in details)
        if not umbrella and not details:
            # No money, no honey
            return
        elif not umbrella and details:
            # Open new umbrella
            umbrella = ActiveAlarm(timestamp=now,
                                   managed_object=self.object.id,
                                   alarm_class=umbrella_cls.id,
                                   severity=u_sev)
            umbrella.save()
            self.logger.info("Opening umbrella alarm %s (%s)", umbrella.id,
                             umbrella_cls.name)
        elif umbrella and not details:
            # Close existing umbrella
            self.logger.info("Clearing umbrella alarm %s (%s)", umbrella.id,
                             umbrella_cls.name)
            umbrella.clear_alarm("Closing umbrella")
        elif umbrella and details and u_sev != umbrella.severity:
            self.logger.info("Change umbrella alarm %s severity %s -> %s (%s)",
                             umbrella.id, umbrella.severity, u_sev,
                             umbrella_cls.name)
            umbrella.change_severity(severity=u_sev)
        # Get existing details for umbrella
        active_details = {}  # (alarm class, path) -> alarm
        if umbrella:
            for da in ActiveAlarm.objects.filter(root=umbrella.id):
                d_path = da.vars.get("path", "")
                active_details[da.alarm_class, d_path] = da
        # Synchronize details
        self.logger.info("Active details: %s", active_details)
        seen = set()
        for d in details:
            d_path = d.get("path", "")
            d_key = (d["alarm_class"], d_path)
            d_sev = d.get("severity", 0)
            # Key for seen details
            seen.add(d_key)
            if d_key in active_details and active_details[
                    d_key].severity != d_sev:
                # Change severity
                self.logger.info("Change detail alarm %s severity %s -> %s",
                                 active_details[d_key].id,
                                 active_details[d_key].severity, d_sev)
                active_details[d_key].change_severity(severity=d_sev)
            elif d_key not in active_details:
                # Create alarm
                self.logger.info("Create detail alarm to path %s", d_key)
                v = d.get("vars", {})
                v["path"] = d_path
                da = ActiveAlarm(timestamp=now,
                                 managed_object=self.object.id,
                                 alarm_class=d["alarm_class"],
                                 severity=d_sev,
                                 vars=v,
                                 root=umbrella.id)
                da.save()
                self.logger.info("Opening detail alarm %s %s (%s)", da.id,
                                 d_path, da.alarm_class.name)
        # Close details when necessary
        for d in set(active_details) - seen:
            self.logger.info("Clearing detail alarm %s", active_details[d].id)
            active_details[d].clear_alarm("Closing")

    def update_alarms(self):
        from noc.fm.models.alarmseverity import AlarmSeverity
        from noc.fm.models.alarmclass import AlarmClass

        prev_status = self.context.get("umbrella_settings", False)
        current_status = self.can_update_alarms()
        self.context["umbrella_settings"] = current_status

        if not prev_status and not current_status:
            return
        self.logger.info("Updating alarm statuses")
        umbrella_cls = AlarmClass.get_by_name(self.umbrella_cls)
        if not umbrella_cls:
            self.logger.info(
                "No umbrella alarm class. Alarm statuses not updated")
            return
        details = []
        if current_status:
            fatal_weight = self.get_fatal_alarm_weight()
            weight = self.get_alarm_weight()
            for p in self.problems:
                if not p["alarm_class"]:
                    continue
                ac = AlarmClass.get_by_name(p["alarm_class"])
                if not ac:
                    self.logger.info("Unknown alarm class %s. Skipping",
                                     p["alarm_class"])
                    continue
                details += [{
                    "alarm_class":
                    ac,
                    "path":
                    p["path"],
                    "severity":
                    AlarmSeverity.severity_for_weight(
                        fatal_weight if p["fatal"] else weight),
                    "vars": {
                        "path": p["path"],
                        "message": p["message"]
                    }
                }]
        else:
            # Clean up all open alarms as they have been disabled
            details = []
        self.update_umbrella(umbrella_cls, details)

    def can_update_alarms(self):
        return False

    def get_fatal_alarm_weight(self):
        return 1

    def get_alarm_weight(self):
        return 1

    def set_artefact(self, name, value=None):
        """
        Set artefact (opaque structure to be passed to following checks)
        :param name: Artefact name
        :param value: Opaque value
        :return:
        """
        if not value:
            if name in self.artefacts:
                del self.artefacts[name]
        else:
            self.artefacts[name] = value or None

    def get_artefact(self, name):
        """
        Get artefact by name
        :param name: artefact name
        :return: artefact
        """
        return self.artefacts.get(name)

    def has_artefact(self, name):
        """
        Check job has existing artefact
        :param name: artefact name
        :return: True, if artefact exists, False otherwise
        """
        return name in self.artefacts
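
A hedged sketch of how a single check inside a discovery job might use the helpers above; the check name, alarm class, and artefact contents are illustrative:

def run_interface_check(job):
    # Time the check so it shows up in the "Timings:" summary
    with job.check_timer("interface"):
        caps = job.get_caps()
        if not caps.get("SNMP"):
            job.set_problem(
                check="interface",
                alarm_class="Discovery | Error | No SNMP",  # illustrative name
                message="SNMP capability is missing",
                fatal=False,
            )
            return
        # ... collect interfaces, then hand data to the next check ...
        job.set_artefact("interface_macs", {"Gi0/1": "00:11:22:33:44:55"})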
Example 7
class BaseScript(six.with_metaclass(BaseScriptMetaclass, object)):
    """
    Service Activation script base class
    """

    # Script name in form of <vendor>.<system>.<name>
    name = None
    # Default script timeout
    TIMEOUT = config.script.timeout
    # Default session timeout
    SESSION_IDLE_TIMEOUT = config.script.session_idle_timeout
    # Default access preference
    DEFAULT_ACCESS_PREFERENCE = "SC"
    # Enable call cache
    # If True, script result will be cached and reused
    # during lifetime of parent script
    cache = False
    # Implemented interface
    interface = None
    # Scripts required by generic script.
    # For common scripts - empty list
    # For generics - list of pairs (script_name, interface)
    requires = []
    #
    base_logger = logging.getLogger(name or "script")
    #
    _x_seq = itertools.count()
    # Sessions
    session_lock = Lock()
    session_cli = {}
    session_mml = {}
    session_rtsp = {}
    # In session mode when active CLI session exists
    # * True -- reuse session
    # * False -- close session and run new without session context
    reuse_cli_session = True
    # In session mode:
    # Should we keep CLI session for reuse by next script
    # * True - keep CLI session for next script
    # * False - close CLI session
    keep_cli_session = True
    # Script-level matchers.
    # Override profile one
    matchers = {}

    # Error classes shortcuts
    ScriptError = ScriptError
    CLISyntaxError = CLISyntaxError
    CLIOperationError = CLIOperationError
    NotSupportedError = NotSupportedError
    UnexpectedResultError = UnexpectedResultError

    hexbin = {
        "0": "0000",
        "1": "0001",
        "2": "0010",
        "3": "0011",
        "4": "0100",
        "5": "0101",
        "6": "0110",
        "7": "0111",
        "8": "1000",
        "9": "1001",
        "a": "1010",
        "b": "1011",
        "c": "1100",
        "d": "1101",
        "e": "1110",
        "f": "1111",
    }

    cli_protocols = {
        "telnet": "noc.core.script.cli.telnet.TelnetCLI",
        "ssh": "noc.core.script.cli.ssh.SSHCLI",
        "beef": "noc.core.script.cli.beef.BeefCLI",
    }

    mml_protocols = {"telnet": "noc.core.script.mml.telnet.TelnetMML"}

    rtsp_protocols = {"tcp": "noc.core.script.rtsp.base.RTSPBase"}
    # Override access preferences for script
    # S - always try SNMP first
    # C - always try CLI first
    # None - use default preferences
    always_prefer = None

    def __init__(
        self,
        service,
        credentials,
        args=None,
        capabilities=None,
        version=None,
        parent=None,
        timeout=None,
        name=None,
        session=None,
        session_idle_timeout=None,
    ):
        self.service = service
        self.tos = config.activator.tos
        self.pool = config.pool
        self.parent = parent
        self._motd = None
        name = name or self.name
        self.logger = PrefixLoggerAdapter(
            self.base_logger,
            "%s] [%s" % (self.name, credentials.get("address", "-")))
        if self.parent:
            self.profile = self.parent.profile
        else:
            self.profile = profile_loader.get_profile(".".join(
                name.split(".")[:2]))()
        self.credentials = credentials or {}
        self.version = version or {}
        self.capabilities = capabilities or {}
        self.timeout = timeout or self.get_timeout()
        self.start_time = None
        self._interface = self.interface()
        self.args = self.clean_input(args) if args else {}
        self.cli_stream = None
        self.mml_stream = None
        self.rtsp_stream = None
        if self.parent:
            self.snmp = self.root.snmp
        elif self.is_beefed:
            self.snmp = BeefSNMP(self)
        else:
            self.snmp = SNMP(self)
        if self.parent:
            self.http = self.root.http
        else:
            self.http = HTTP(self)
        self.to_disable_pager = not self.parent and self.profile.command_disable_pager
        self.scripts = ScriptsHub(self)
        # Store session id
        self.session = session
        self.session_idle_timeout = session_idle_timeout or self.SESSION_IDLE_TIMEOUT
        # Cache CLI and SNMP calls, if set
        self.is_cached = False
        # Suitable only when self.parent is None.
        # Cached results for scripts marked with "cache"
        self.call_cache = {}
        # Suitable only when self.parent is None
        # Cached results of self.cli calls
        self.cli_cache = {}
        #
        self.http_cache = {}
        self.partial_result = None
        # @todo: Get native encoding from ManagedObject
        self.native_encoding = "utf8"
        # Tracking
        self.to_track = False
        self.cli_tracked_data = {}  # command -> [packets]
        self.cli_tracked_command = None
        # state -> [..]
        self.cli_fsm_tracked_data = {}
        #
        if not parent and version and not name.endswith(".get_version"):
            self.logger.debug("Filling get_version cache with %s", version)
            s = name.split(".")
            self.set_cache("%s.%s.get_version" % (s[0], s[1]), {}, version)
        # Fill matchers
        if not self.name.endswith(".get_version"):
            self.apply_matchers()
        #
        if self.profile.setup_script:
            self.profile.setup_script(self)

    def __call__(self, *args, **kwargs):
        self.args = kwargs
        return self.run()

    def apply_matchers(self):
        """
        Process matchers and apply is_XXX properties
        :return:
        """
        def get_matchers(c, matchers):
            return dict((m, match(c, matchers[m])) for m in matchers)

        # Match context
        # @todo: Add capabilities
        ctx = self.version or {}
        if self.capabilities:
            ctx["caps"] = self.capabilities
        # Calculate matches
        v = get_matchers(ctx, self.profile.matchers)
        v.update(get_matchers(ctx, self.matchers))
        #
        for k in v:
            self.logger.debug("%s = %s", k, v[k])
            setattr(self, k, v[k])

    def clean_input(self, args):
        """
        Cleanup input parameters against interface
        """
        return self._interface.script_clean_input(self.profile, **args)

    def clean_output(self, result):
        """
        Clean script result against interface
        """
        return self._interface.script_clean_result(self.profile, result)

    def run(self):
        """
        Run script
        """
        with Span(server="activator",
                  service=self.name,
                  in_label=self.credentials.get("address")):
            self.start_time = perf_counter()
            self.logger.debug("Running. Input arguments: %s, timeout %s",
                              self.args, self.timeout)
            # Use cached result when available
            cache_hit = False
            if self.cache and self.parent:
                try:
                    result = self.get_cache(self.name, self.args)
                    self.logger.info("Using cached result")
                    cache_hit = True
                except KeyError:
                    pass
            # Execute script
            if not cache_hit:
                try:
                    result = self.execute(**self.args)
                    if self.cache and self.parent and result:
                        self.logger.info("Caching result")
                        self.set_cache(self.name, self.args, result)
                finally:
                    if not self.parent:
                        # Close SNMP socket when necessary
                        self.close_snmp()
                        # Close CLI socket when necessary
                        self.close_cli_stream()
                        # Close MML socket when necessary
                        self.close_mml_stream()
                        # Close RTSP socket when necessary
                        self.close_rtsp_stream()
                        # Close HTTP Client
                        self.http.close()
            # Clean result
            result = self.clean_output(result)
            self.logger.debug("Result: %s", result)
            runtime = perf_counter() - self.start_time
            self.logger.info("Complete (%.2fms)", runtime * 1000)
        return result

    @classmethod
    def compile_match_filter(cls, *args, **kwargs):
        """
        Compile arguments into a version check function.
        Returns a callable accepting self and a version hash as arguments
        """
        c = [lambda self, x, g=f: g(x) for f in args]
        for k, v in six.iteritems(kwargs):
            # Split to field name and lookup operator
            if "__" in k:
                f, o = k.split("__")
            else:
                f = k
                o = "exact"
            # Check field name
            if f not in ("vendor", "platform", "version", "image"):
                raise Exception("Invalid field '%s'" % f)
            # Compile lookup functions
            if o == "exact":
                c += [lambda self, x, f=f, v=v: x[f] == v]
            elif o == "iexact":
                c += [lambda self, x, f=f, v=v: x[f].lower() == v.lower()]
            elif o == "startswith":
                c += [lambda self, x, f=f, v=v: x[f].startswith(v)]
            elif o == "istartswith":
                c += [
                    lambda self, x, f=f, v=v: x[f].lower().startswith(v.lower(
                    ))
                ]
            elif o == "endswith":
                c += [lambda self, x, f=f, v=v: x[f].endswith(v)]
            elif o == "iendswith":
                c += [
                    lambda self, x, f=f, v=v: x[f].lower().endswith(v.lower())
                ]
            elif o == "contains":
                c += [lambda self, x, f=f, v=v: v in x[f]]
            elif o == "icontains":
                c += [lambda self, x, f=f, v=v: v.lower() in x[f].lower()]
            elif o == "in":
                c += [lambda self, x, f=f, v=v: x[f] in v]
            elif o == "regex":
                c += [
                    lambda self, x, f=f, v=re.compile(v): v.search(x[f]) is
                    not None
                ]
            elif o == "iregex":
                c += [
                    lambda self, x, f=f, v=re.compile(v, re.IGNORECASE): v.
                    search(x[f]) is not None
                ]
            elif o == "isempty":  # Empty string or null
                c += [lambda self, x, f=f, v=v: not x[f] if v else x[f]]
            elif f == "version":
                if o == "lt":  # <
                    c += [
                        lambda self, x, v=v: self.profile.cmp_version(
                            x["version"], v) < 0
                    ]
                elif o == "lte":  # <=
                    c += [
                        lambda self, x, v=v: self.profile.cmp_version(
                            x["version"], v) <= 0
                    ]
                elif o == "gt":  # >
                    c += [
                        lambda self, x, v=v: self.profile.cmp_version(
                            x["version"], v) > 0
                    ]
                elif o == "gte":  # >=
                    c += [
                        lambda self, x, v=v: self.profile.cmp_version(
                            x["version"], v) >= 0
                    ]
                else:
                    raise Exception("Invalid lookup operation: %s" % o)
            else:
                raise Exception("Invalid lookup operation: %s" % o)
        # Combine expressions into single lambda
        return reduce(
            lambda x, y: lambda self, v, x=x, y=y: (x(self, v) and y(self, v)),
            c,
            lambda self, x: True,
        )
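
    # A minimal usage sketch of compile_match_filter() (vendor and version
    # values below are hypothetical): keyword lookups are AND-combined into a
    # single predicate over a version hash.
    #
    #     check = BaseScript.compile_match_filter(vendor="Cisco",
    #                                             version__gte="12.2")
    #     check(script, {"vendor": "Cisco", "version": "15.0"})
    #     # -> True, provided the profile's cmp_version ranks "15.0" >= "12.2"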

    @classmethod
    def match(cls, *args, **kwargs):
        """
        execute method decorator
        """
        def wrap(f):
            # Append to the execute chain
            if hasattr(f, "_match"):
                old_filter = f._match
                f._match = lambda self, v, old_filter=old_filter, new_filter=new_filter: new_filter(
                    self, v) or old_filter(self, v)
            else:
                f._match = new_filter
            f._seq = next(cls._x_seq)
            return f

        # Compile check function
        new_filter = cls.compile_match_filter(*args, **kwargs)
        # Return decorated function
        return wrap
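
    # A hedged sketch of the deprecated @match decorator (the platform value is
    # hypothetical): stacking several @match decorators on one method
    # OR-combines their filters, and at run time the first handler in the
    # execute chain whose filter matches get_version() output is called.
    #
    #     @BaseScript.match(platform__contains="3750")
    #     def execute_switch(self, **kwargs):
    #         ...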

    def match_version(self, *args, **kwargs):
        """
        inline version for BaseScript.match
        """
        if not self.version:
            self.version = self.scripts.get_version()
        return self.compile_match_filter(*args, **kwargs)(self, self.version)

    def execute(self, **kwargs):
        """
        Default script behavior:
        Pass through _execute_chain and call appropriate handler
        """
        if self._execute_chain and not self.name.endswith(".get_version"):
            # Deprecated @match chain
            self.logger.info(
                "WARNING: Using deprecated @BaseScript.match() decorator. "
                "Consider porting to the new matcher API")
            # Get version information
            if not self.version:
                self.version = self.scripts.get_version()
            # Find and execute proper handler
            for f in self._execute_chain:
                if f._match(self, self.version):
                    return f(self, **kwargs)
            # Raise error
            raise self.NotSupportedError()
        else:
            # New SNMP/CLI API
            return self.call_method(cli_handler=self.execute_cli,
                                    snmp_handler=self.execute_snmp,
                                    **kwargs)

    def call_method(self,
                    cli_handler=None,
                    snmp_handler=None,
                    fallback_handler=None,
                    **kwargs):
        """
        Call function depending on access_preference
        :param cli_handler: String or callable to call on CLI access method
        :param snmp_handler: String or callable to call on SNMP access method
        :param fallback_handler: String or callable to call if no access method matched
        :param kwargs:
        :return:
        """
        # Select proper handler
        access_preference = self.get_access_preference() + "*"
        for m in access_preference:
            # Select proper handler
            if m == "C":
                handler = cli_handler
            elif m == "S":
                if self.has_snmp():
                    handler = snmp_handler
                else:
                    self.logger.debug(
                        "SNMP is not enabled. Passing to next method")
                    continue
            elif m == "*":
                handler = fallback_handler
            else:
                raise self.NotSupportedError("Invalid access method '%s'" % m)
            # Resolve handler when necessary
            if isinstance(handler, six.string_types):
                handler = getattr(self, handler, None)
            if handler is None:
                self.logger.debug("No '%s' handler. Passing to next method" %
                                  m)
                continue
            # Call handler
            try:
                r = handler(**kwargs)
                if isinstance(r, PartialResult):
                    if self.partial_result:
                        self.partial_result.update(r.result)
                    else:
                        self.partial_result = r.result
                    self.logger.debug(
                        "Partial result: %r. Passing to next method",
                        self.partial_result)
                else:
                    return r
            except self.snmp.TimeOutError:
                self.logger.info("SNMP timeout. Passing to next method")
                if access_preference == "S*":
                    self.logger.info("Last S method break by timeout.")
                    raise self.snmp.TimeOutError
            except NotImplementedError:
                self.logger.debug(
                    "Access method '%s' is not implemented. Passing to next method",
                    m)
        raise self.NotSupportedError(
            "Access preference '%s' is not supported" % access_preference[:-1])

    def execute_cli(self, **kwargs):
        """
        Process script using CLI
        :param kwargs:
        :return:
        """
        raise NotImplementedError("execute_cli() is not implemented")

    def execute_snmp(self, **kwargs):
        """
        Process script using SNMP
        :param kwargs:
        :return:
        """
        raise NotImplementedError("execute_snmp() is not implemented")

    def cleaned_config(self, config):
        """
        Clean up config from all unnecessary trash
        """
        return self.profile.cleaned_config(config)

    def strip_first_lines(self, text, lines=1):
        """
        Strip first *lines*
        """
        t = text.split("\n")
        if len(t) <= lines:
            return ""
        else:
            return "\n".join(t[lines:])

    def expand_rangelist(self, s):
        """
        Expand expressions like "1,2,5-7" to [1, 2, 5, 6, 7]
        """
        result = {}
        for x in s.split(","):
            x = x.strip()
            if x == "":
                continue
            if "-" in x:
                left, right = [int(y) for y in x.split("-")]
                if left > right:
                    x = right
                    right = left
                    left = x
                for i in range(left, right + 1):
                    result[i] = None
            else:
                result[int(x)] = None
        return sorted(result.keys())
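
    # Behaviour sketch (assumed inputs): reversed bounds are swapped and
    # duplicates are collapsed.
    #
    #     self.expand_rangelist("1,2,5-7")  # -> [1, 2, 5, 6, 7]
    #     self.expand_rangelist("7-5,6")    # -> [5, 6, 7]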

    rx_detect_sep = re.compile(r"^(.*?)\d+$")

    def expand_interface_range(self, s):
        """
        Convert interface range expression to a list
        of interfaces
        "Gi 1/1-3,Gi 1/7" -> ["Gi 1/1", "Gi 1/2", "Gi 1/3", "Gi 1/7"]
        "1:1-3" -> ["1:1", "1:2", "1:3"]
        "1:1-1:3" -> ["1:1", "1:2", "1:3"]
        :param s: Comma-separated list
        :return:
        """
        r = set()
        for x in s.split(","):
            x = x.strip()
            if not x:
                continue
            if "-" in x:
                # Expand range
                f, t = [y.strip() for y in x.split("-")]
                # Detect common prefix
                match = self.rx_detect_sep.match(f)
                if not match:
                    raise ValueError(x)
                prefix = match.group(1)
                # Detect range boundaries
                start = int(f[len(prefix):])
                if is_int(t):
                    stop = int(t)  # Just integer
                else:
                    if not t.startswith(prefix):
                        raise ValueError(x)
                    stop = int(t[len(prefix):])  # Prefixed
                if start > stop:
                    raise ValueError(x)
                for i in range(start, stop + 1):
                    r.add(prefix + str(i))
            else:
                r.add(x)
        return sorted(r)
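
    # Behaviour sketch (interface names are hypothetical): the non-numeric
    # prefix is detected from the left-hand side of each range.
    #
    #     self.expand_interface_range("Gi 1/1-3,Gi 1/7")
    #     # -> ["Gi 1/1", "Gi 1/2", "Gi 1/3", "Gi 1/7"]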

    def macs_to_ranges(self, macs):
        """
        Convert a list of MAC addresses to ranges
        :param macs: Iterable yielding mac addresses
        :returns: [(from, to), ..]
        """
        r = []
        for m in sorted(MAC(x) for x in macs):
            if r:
                if r[-1][1].shift(1) == m:
                    # Expand last range
                    r[-1][1] = m
                else:
                    r += [[m, m]]
            else:
                r += [[m, m]]
        return [(str(x[0]), str(x[1])) for x in r]
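
    # Behaviour sketch (addresses are hypothetical, assuming MAC() renders them
    # back in the same colon-separated form): consecutive addresses are folded
    # into a single (from, to) range.
    #
    #     self.macs_to_ranges(["00:11:22:33:44:01", "00:11:22:33:44:02",
    #                          "00:11:22:33:44:10"])
    #     # -> [("00:11:22:33:44:01", "00:11:22:33:44:02"),
    #     #     ("00:11:22:33:44:10", "00:11:22:33:44:10")]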

    def hexstring_to_mac(self, s):
        """Convert a 6-octet string to MAC address"""
        return ":".join(["%02X" % ord(x) for x in s])

    @property
    def root(self):
        """Get root script"""
        if self.parent:
            return self.parent.root
        else:
            return self

    def get_cache(self, key1, key2):
        """Get cached result or raise KeyError"""
        s = self.root
        return s.call_cache[repr(key1)][repr(key2)]

    def set_cache(self, key1, key2, value):
        """Set cached result"""
        key1 = repr(key1)
        key2 = repr(key2)
        s = self.root
        if key1 not in s.call_cache:
            s.call_cache[key1] = {}
        s.call_cache[key1][key2] = value

    def configure(self):
        """Returns configuration context"""
        return ConfigurationContextManager(self)

    def cached(self):
        """
        Return cached context manager. All nested CLI and SNMP GET/GETNEXT
        calls will be cached.

        Usage:

        with self.cached():
            self.cli(".....)
            self.scripts.script()
        """
        return CacheContextManager(self)

    def enter_config(self):
        """Enter configuration mote"""
        if self.profile.command_enter_config:
            self.cli(self.profile.command_enter_config)

    def leave_config(self):
        """Leave configuration mode"""
        if self.profile.command_leave_config:
            self.cli(self.profile.command_leave_config)
            self.cli("")  # Guard: empty command to wait until configuration is finally written

    def save_config(self, immediately=False):
        """Save current config"""
        if immediately:
            if self.profile.command_save_config:
                self.cli(self.profile.command_save_config)
        else:
            self.schedule_to_save()

    def schedule_to_save(self):
        self.need_to_save = True
        if self.parent:
            self.parent.schedule_to_save()

    def set_motd(self, motd):
        self._motd = motd

    @property
    def motd(self):
        """
        Return message of the day
        """
        if self._motd:
            return self._motd
        return self.get_cli_stream().get_motd()

    def re_search(self, rx, s, flags=0):
        """
        Match s against regular expression rx using re.search
        Raise UnexpectedResultError if regular expression is not matched.
        Returns match object.
        rx can be string or compiled regular expression
        """
        if isinstance(rx, six.string_types):
            rx = re.compile(rx, flags)
        match = rx.search(s)
        if match is None:
            raise UnexpectedResultError()
        return match

    def re_match(self, rx, s, flags=0):
        """
        Match s against regular expression rx using re.match
        Raise UnexpectedResultError if regular expression is not matched.
        Returns match object.
        rx can be string or compiled regular expression
        """
        if isinstance(rx, six.string_types):
            rx = re.compile(rx, flags)
        match = rx.match(s)
        if match is None:
            raise UnexpectedResultError()
        return match

    _match_lines_cache = {}

    @classmethod
    def match_lines(cls, rx, s):
        k = id(rx)
        if k not in cls._match_lines_cache:
            _rx = [re.compile(line, re.IGNORECASE) for line in rx]
            cls._match_lines_cache[k] = _rx
        else:
            _rx = cls._match_lines_cache[k]
        ctx = {}
        idx = 0
        r = _rx[0]
        for line in s.splitlines():
            line = line.strip()
            match = r.search(line)
            if match:
                ctx.update(match.groupdict())
                idx += 1
                if idx == len(_rx):
                    return ctx
                r = _rx[idx]
        return None
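
    # A minimal sketch of match_lines() (patterns and text are hypothetical):
    # patterns are consumed in order and the merged groupdict is returned once
    # the last pattern has matched; None is returned otherwise.
    #
    #     rx = [r"^Model:\s+(?P<platform>\S+)", r"^Version:\s+(?P<version>\S+)"]
    #     BaseScript.match_lines(rx, "Model: X100\nVersion: 1.2.3")
    #     # -> {"platform": "X100", "version": "1.2.3"}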

    def find_re(self, iter, s):
        """
        Find first matching regular expression
        or raise Unexpected result error
        """
        for r in iter:
            if r.search(s):
                return r
        raise UnexpectedResultError()

    def hex_to_bin(self, s):
        """
        Convert a raw octet string to a binary (bit) string.
        Each octet is rendered as two hexadecimal digits and expanded to bits
        :param s: Input string
        :return: Binary string
        :rtype: str
        """
        return "".join(self.hexbin[c]
                       for c in "".join("%02x" % ord(d) for d in s))

    def push_prompt_pattern(self, pattern):
        self.get_cli_stream().push_prompt_pattern(pattern)

    def pop_prompt_pattern(self):
        self.get_cli_stream().pop_prompt_pattern()

    def has_oid(self, oid):
        """
        Check whether the object responds to the given oid
        """
        try:
            return bool(self.snmp.get(oid))
        except self.snmp.TimeOutError:
            return False

    def get_timeout(self):
        return self.TIMEOUT

    def cli(
        self,
        cmd,
        command_submit=None,
        bulk_lines=None,
        list_re=None,
        cached=False,
        file=None,
        ignore_errors=False,
        allow_empty_response=True,
        nowait=False,
        obj_parser=None,
        cmd_next=None,
        cmd_stop=None,
    ):
        # type: (six.text_type, Optional[six.binary_type], Any, Any, bool, Optional[six.text_type], Any, Any, Any, Any, Any, Any) -> six.text_type
        """
        Execute CLI command and return result. Initiate cli session
        when necessary.

        if list_re is None, return a string
        if list_re is regular expression object, return a list of dicts (group name -> value),
            one dict per matched line

        :param cmd: CLI command to execute
        :param command_submit: Optional suffix to submit command. Profile's one used by default
        :param bulk_lines:
        :param list_re:
        :param cached: True if result of execution may be cached
        :param file: Path to the file containing debugging result
        :param ignore_errors:
        :param allow_empty_response: Allow empty output. If False, ignore the prompt and keep waiting for output
        :param nowait:
        """
        def format_result(result):
            if list_re:
                x = []
                for l in result.splitlines():
                    match = list_re.match(l.strip())
                    if match:
                        x += [match.groupdict()]
                return x
            else:
                return result

        if file:
            # Read from file
            with open(file) as f:
                return format_result(f.read())
        if cached:
            # Cached result
            r = self.root.cli_cache.get(cmd)
            if r is not None:
                self.logger.debug("Use cached result")
                return format_result(r)
        # Effective command submit suffix
        if command_submit is None:
            command_submit = self.profile.command_submit
        # Encode submitted command
        submitted_cmd = smart_bytes(
            cmd, encoding=self.native_encoding) + command_submit
        # Run command
        stream = self.get_cli_stream()
        if self.to_track:
            self.cli_tracked_command = cmd
        r = stream.execute(
            submitted_cmd,
            obj_parser=obj_parser,
            cmd_next=cmd_next,
            cmd_stop=cmd_stop,
            ignore_errors=ignore_errors,
            allow_empty_response=allow_empty_response,
        )
        if isinstance(r, six.binary_type):
            r = smart_text(r, errors="ignore", encoding=self.native_encoding)
        if isinstance(r, six.text_type):
            # Check for syntax errors
            if not ignore_errors:
                # Then check for operation error
                if (self.profile.rx_pattern_operation_error
                        and self.profile.rx_pattern_operation_error.search(r)):
                    raise self.CLIOperationError(r)
            # Echo cancelation
            r = self.echo_cancelation(r, cmd)
            # Store cli cache when necessary
            if cached:
                self.root.cli_cache[cmd] = r
        return format_result(r)
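
    # A minimal sketch of self.cli() with list_re (the command and pattern are
    # hypothetical): every line matching the regex contributes one groupdict to
    # the returned list.
    #
    #     rx_vlan = re.compile(r"^(?P<vlan_id>\d+)\s+(?P<name>\S+)")
    #     self.cli("show vlan", list_re=rx_vlan)
    #     # -> [{"vlan_id": "1", "name": "default"}, ...]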

    def echo_cancelation(self, r, cmd):
        # type: (six.text_type, six.text_type) -> six.text_type
        """
        Adaptive echo cancelation

        :param r:
        :param cmd:
        :return:
        """
        if r[:4096].lstrip().startswith(cmd):
            r = r.lstrip()
            if r.startswith(cmd + "\n"):
                # Remove first line
                r = self.strip_first_lines(r.lstrip())
            else:
                # Some switches, like ProCurve do not send \n after the echo
                r = r[len(cmd):]
        return r

    def get_cli_stream(self):
        if self.parent:
            return self.root.get_cli_stream()
        if not self.cli_stream and self.session:
            # Try to get cached session's CLI
            with self.session_lock:
                self.cli_stream = self.session_cli.get(self.session)
                if self.cli_stream:
                    if self.cli_stream.is_closed:
                        # Stream closed by external reason,
                        # mark as invalid and start new one
                        self.cli_stream = None
                    # Remove stream from pool to prevent cli session hijacking
                    del self.session_cli[self.session]
            if self.cli_stream:
                if self.to_reuse_cli_session():
                    self.logger.debug("Using cached session's CLI")
                    self.cli_stream.set_script(self)
                else:
                    self.logger.debug(
                        "Script cannot reuse existing CLI session, starting new one"
                    )
                    self.close_cli_stream()
        if not self.cli_stream:
            protocol = self.credentials.get("cli_protocol", "telnet")
            self.logger.debug("Open %s CLI", protocol)
            self.cli_stream = get_handler(self.cli_protocols[protocol])(
                self, tos=self.tos)
            # Store to the sessions
            if self.session:
                with self.session_lock:
                    self.session_cli[self.session] = self.cli_stream
            self.cli_stream.setup_session()
            # Disable pager when necessary
            # @todo: Move to CLI
            if self.to_disable_pager:
                self.logger.debug("Disable paging")
                self.to_disable_pager = False
                if isinstance(self.profile.command_disable_pager,
                              six.string_types):
                    self.cli(self.profile.command_disable_pager,
                             ignore_errors=True)
                elif isinstance(self.profile.command_disable_pager, list):
                    for cmd in self.profile.command_disable_pager:
                        self.cli(cmd, ignore_errors=True)
                else:
                    raise UnexpectedResultError
        return self.cli_stream

    def close_cli_stream(self):
        if self.parent:
            return
        if self.cli_stream:
            if self.session and self.to_keep_cli_session():
                # Return cli stream to pool
                self.session_cli[self.session] = self.cli_stream
                # Schedule stream closing
                self.cli_stream.deferred_close(self.session_idle_timeout)
            else:
                self.cli_stream.shutdown_session()
                self.cli_stream.close()
            self.cli_stream = None

    def close_snmp(self):
        if self.parent:
            return
        if self.snmp:
            self.snmp.close()
            self.snmp = None

    def mml(self, cmd, **kwargs):
        """
        Execute MML command and return result. Initiate MML session when necessary
        :param cmd:
        :param kwargs:
        :return:
        """
        stream = self.get_mml_stream()
        r = stream.execute(cmd, **kwargs)
        return r

    def get_mml_stream(self):
        if self.parent:
            return self.root.get_mml_stream()
        if not self.mml_stream and self.session:
            # Try to get cached session's CLI
            with self.session_lock:
                self.mml_stream = self.session_mml.get(self.session)
                if self.mml_stream and self.mml_stream.is_closed:
                    self.mml_stream = None
                    del self.session_mml[self.session]
            if self.mml_stream:
                if self.to_reuse_cli_session():
                    self.logger.debug("Using cached session's MML")
                    self.mml_stream.set_script(self)
                else:
                    self.logger.debug(
                        "Script cannot reuse existing MML session, starting new one"
                    )
                    self.close_mml_stream()
        if not self.mml_stream:
            protocol = self.credentials.get("cli_protocol", "telnet")
            self.logger.debug("Open %s MML", protocol)
            self.mml_stream = get_handler(self.mml_protocols[protocol])(
                self, tos=self.tos)
            # Store to the sessions
            if self.session:
                with self.session_lock:
                    self.session_mml[self.session] = self.mml_stream
        return self.mml_stream

    def close_mml_stream(self):
        if self.parent:
            return
        if self.mml_stream:
            if self.session and self.to_keep_cli_session():
                self.mml_stream.deferred_close(self.session_idle_timeout)
            else:
                self.mml_stream.close()
            self.mml_stream = None

    def rtsp(self, method, path, **kwargs):
        """
        Execute RTSP command and return result. Initiate RTSP session when necessary
        :param method:
        :param path:
        :param kwargs:
        :return:
        """
        stream = self.get_rtsp_stream()
        r = stream.execute(path, method, **kwargs)
        return r

    def get_rtsp_stream(self):
        if self.parent:
            return self.root.get_rtsp_stream()
        if not self.rtsp_stream and self.session:
            # Try to get cached session's CLI
            with self.session_lock:
                self.rtsp_stream = self.session_rtsp.get(self.session)
                if self.rtsp_stream and self.rtsp_stream.is_closed:
                    self.rtsp_stream = None
                    del self.session_rtsp[self.session]
            if self.rtsp_stream:
                if self.to_reuse_cli_session():
                    self.logger.debug("Using cached session's RTSP")
                    self.rtsp_stream.set_script(self)
                else:
                    self.logger.debug(
                        "Script cannot reuse existing RTSP session, starting new one"
                    )
                    self.close_rtsp_stream()
        if not self.rtsp_stream:
            protocol = "tcp"
            self.logger.debug("Open %s RTSP", protocol)
            self.rtsp_stream = get_handler(self.rtsp_protocols[protocol])(
                self, tos=self.tos)
            # Store to the sessions
            if self.session:
                with self.session_lock:
                    self.session_rtsp[self.session] = self.rtsp_stream
        return self.rtsp_stream

    def close_rtsp_stream(self):
        if self.parent:
            return
        if self.rtsp_stream:
            if self.session and self.to_keep_cli_session():
                self.rtsp_stream.deferred_close(self.session_idle_timeout)
            else:
                self.rtsp_stream.close()
            self.rtsp_stream = None

    def close_current_session(self):
        if self.session:
            self.close_session(self.session)

    @classmethod
    def close_session(cls, session_id):
        """
        Explicit session closing
        :return:
        """
        with cls.session_lock:
            cli_stream = cls.session_cli.get(session_id)
            if cli_stream:
                del cls.session_cli[session_id]
            mml_stream = cls.session_mml.get(session_id)
            if mml_stream:
                del cls.session_mml[session_id]
            rtsp_stream = cls.session_rtsp.get(session_id)
            if rtsp_stream:
                del cls.session_rtsp[session_id]
        if cli_stream and not cli_stream.is_closed:
            cli_stream.shutdown_session()
            cli_stream.close()
        if mml_stream and not mml_stream.is_closed:
            mml_stream.shutdown_session()
            mml_stream.close()
        if rtsp_stream and not rtsp_stream.is_closed:
            rtsp_stream.shutdown_session()
            rtsp_stream.close()

    def get_access_preference(self):
        preferred = self.get_always_preferred()
        r = self.credentials.get("access_preference",
                                 self.DEFAULT_ACCESS_PREFERENCE)
        if preferred and preferred in r:
            return preferred + "".join(x for x in r if x != preferred)
        return r
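
    # Behaviour sketch (credential value is hypothetical): always_prefer moves
    # the preferred method to the front of the preference string.
    #
    #     # credentials["access_preference"] == "CS", always_prefer == "S"
    #     self.get_access_preference()  # -> "SC"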

    def get_always_preferred(self):
        """
        Return always preferred access method
        :return:
        """
        return self.always_prefer

    def has_cli_access(self):
        return "C" in self.get_access_preference()

    def has_snmp_access(self):
        return "S" in self.get_access_preference() and self.has_snmp()

    def has_cli_only_access(self):
        return self.has_cli_access() and not self.has_snmp_access()

    def has_snmp_only_access(self):
        return not self.has_cli_access() and self.has_snmp_access()

    def has_snmp(self):
        """
        Check whether equipment has SNMP enabled
        """
        if self.has_capability("SNMP", allow_zero=True):
            # If having SNMP caps - check it and credential
            return bool(self.credentials.get(
                "snmp_ro")) and self.has_capability("SNMP")
        else:
            # if SNMP caps not exist check credential
            return bool(self.credentials.get("snmp_ro"))

    def has_snmp_v1(self):
        return self.has_capability("SNMP | v1")

    def has_snmp_v2c(self):
        return self.has_capability("SNMP | v2c")

    def has_snmp_v3(self):
        return self.has_capability("SNMP | v3")

    def has_snmp_bulk(self):
        """
        Check whether equipment supports SNMP BULK
        """
        return self.has_capability("SNMP | Bulk")

    def has_capability(self, capability, allow_zero=False):
        """
        Check whether equipment supports capability
        """
        if allow_zero:
            return self.capabilities.get(capability) is not None
        else:
            return bool(self.capabilities.get(capability))

    def ignored_exceptions(self, iterable):
        """
        Context manager to silently ignore specified exceptions
        """
        return IgnoredExceptionsContextManager(iterable)

    def iter_pairs(self, g, offset=0):
        """
        Convert iterable g to pairs,
        i.e.
        [1, 2, 3, 4] -> [(1, 2), (3, 4)]
        :param g: Iterable
        :param offset: Skip first *offset* records
        :return:
        """
        g = iter(g)
        if offset:
            for _ in range(offset):
                next(g)
        return zip(g, g)
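
    # Behaviour sketch:
    #
    #     list(self.iter_pairs([1, 2, 3, 4]))               # -> [(1, 2), (3, 4)]
    #     list(self.iter_pairs([0, 1, 2, 3, 4], offset=1))  # -> [(1, 2), (3, 4)]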

    def to_reuse_cli_session(self):
        return self.reuse_cli_session

    def to_keep_cli_session(self):
        return self.keep_cli_session

    def start_tracking(self):
        self.logger.debug("Start tracking")
        self.to_track = True

    def stop_tracking(self):
        self.logger.debug("Stop tracking")
        self.to_track = False
        self.cli_tracked_data = {}

    def push_cli_tracking(self, r, state):
        if state == "prompt":
            if self.cli_tracked_command in self.cli_tracked_data:
                self.cli_tracked_data[self.cli_tracked_command] += [r]
            else:
                self.cli_tracked_data[self.cli_tracked_command] = [r]
        elif state in self.cli_fsm_tracked_data:
            self.cli_fsm_tracked_data[state] += [r]
        else:
            self.cli_fsm_tracked_data[state] = [r]

    def push_snmp_tracking(self, oid, tlv):
        self.logger.debug("PUSH SNMP %s: %r", oid, tlv)

    def iter_cli_tracking(self):
        """
        Yields command, packets for collected data
        :return:
        """
        for cmd in self.cli_tracked_data:
            self.logger.debug("Collecting %d tracked CLI items",
                              len(self.cli_tracked_data[cmd]))
            yield cmd, self.cli_tracked_data[cmd]
        self.cli_tracked_data = {}

    def iter_cli_fsm_tracking(self):
        for state in self.cli_fsm_tracked_data:
            yield state, self.cli_fsm_tracked_data[state]

    def request_beef(self):
        """
        Download and return beef
        :return:
        """
        if not hasattr(self, "_beef"):
            self.logger.debug("Requesting beef")
            beef_storage_url = self.credentials.get("beef_storage_url")
            beef_path = self.credentials.get("beef_path")
            if not beef_storage_url:
                self.logger.debug("No storage URL")
                self._beef = None
                return None
            if not beef_path:
                self.logger.debug("No beef path")
                self._beef = None
                return None
            from .beef import Beef

            beef = Beef.load(beef_storage_url, beef_path)
            self._beef = beef
        return self._beef

    @property
    def is_beefed(self):
        return self.credentials.get("cli_protocol") == "beef"
Example no. 8
0
class RPCProxy(object):
    """
    API Proxy
    """

    RPCError = RPCError

    def __init__(self, service, service_name, sync=False, hints=None):
        self._logger = PrefixLoggerAdapter(logger, service_name)
        self._service = service
        self._service_name = service_name
        self._api = service_name.split("-")[0]
        self._tid = itertools.count()
        self._transactions = {}
        self._hints = hints
        self._sync = sync

    def __getattr__(self, item):
        @tornado.gen.coroutine
        def _call(method, *args, **kwargs):
            @tornado.gen.coroutine
            def make_call(url, body, limit=3):
                req_headers = {
                    "X-NOC-Calling-Service": self._service.name,
                    "Content-Type": "text/json",
                }
                sample = 1 if span_ctx and span_id else 0
                with Span(
                        server=self._service_name,
                        service=method,
                        sample=sample,
                        context=span_ctx,
                        parent=span_id,
                ) as span:
                    if sample:
                        req_headers["X-NOC-Span-Ctx"] = span.span_context
                        req_headers["X-NOC-Span"] = span.span_id
                    code, headers, data = yield fetch(
                        url,
                        method="POST",
                        headers=req_headers,
                        body=body,
                        connect_timeout=CONNECT_TIMEOUT,
                        request_timeout=REQUEST_TIMEOUT,
                    )
                    # Process response
                    if code == 200:
                        raise tornado.gen.Return(data)
                    elif code == 307:
                        # Process redirect
                        if not limit:
                            raise RPCException("Redirects limit exceeded")
                        url = headers.get("location")
                        self._logger.debug("Redirecting to %s", url)
                        r = yield make_call(url, data, limit - 1)
                        raise tornado.gen.Return(r)
                    elif code in (598, 599):
                        span.error_code = code
                        self._logger.debug("Timed out")
                        raise tornado.gen.Return(None)
                    else:
                        span.error_code = code
                        raise RPCHTTPError("HTTP Error %s: %s" % (code, body))

            t0 = perf_counter()
            self._logger.debug(
                "[%sCALL>] %s.%s(%s, %s)",
                "SYNC " if self._sync else "",
                self._service_name,
                method,
                args,
                kwargs,
            )
            metrics["rpc_call", ("called_service", self._service_name),
                    ("method", method)] += 1
            tid = next(self._tid)
            msg = {"method": method, "params": list(args)}
            is_notify = "_notify" in kwargs
            if not is_notify:
                msg["id"] = tid
            body = ujson.dumps(msg)
            # Get services
            response = None
            for t in self._service.iter_rpc_retry_timeout():
                # Resolve service against service catalog
                if self._hints:
                    svc = random.choice(self._hints)
                else:
                    svc = yield self._service.dcs.resolve(self._service_name)
                response = yield make_call(
                    "http://%s/api/%s/" % (svc, self._api), body)
                if response:
                    break
                else:
                    yield tornado.gen.sleep(t)
            t = perf_counter() - t0
            self._logger.debug("[CALL<] %s.%s (%.2fms)", self._service_name,
                               method, t * 1000)
            if response:
                if not is_notify:
                    try:
                        result = ujson.loads(response)
                    except ValueError as e:
                        raise RPCHTTPError("Cannot decode json: %s" % e)
                    if result.get("error"):
                        self._logger.error("RPC call failed: %s",
                                           result["error"])
                        raise RPCRemoteError(
                            "RPC call failed: %s" % result["error"],
                            remote_code=result.get("code", None),
                        )
                    else:
                        raise tornado.gen.Return(result["result"])
                else:
                    # Notifications return None
                    raise tornado.gen.Return()
            else:
                raise RPCNoService("No active service %s found" %
                                   self._service_name)

        @tornado.gen.coroutine
        def async_wrapper(*args, **kwargs):
            result = yield _call(item, *args, **kwargs)
            raise tornado.gen.Return(result)

        def sync_wrapper(*args, **kwargs):
            @tornado.gen.coroutine
            def _sync_call():
                try:
                    r = yield _call(item, *args, **kwargs)
                    result.append(r)
                except tornado.gen.Return as e:
                    result.append(e.value)
                except Exception:
                    error.append(sys.exc_info())
                finally:
                    ev.set()

            ev = threading.Event()
            result = []
            error = []
            self._service.ioloop.add_callback(_sync_call)
            ev.wait()
            if error:
                six.reraise(*error[0])
            else:
                return result[0]

        if item.startswith("_"):
            return self.__dict__[item]
        span_ctx, span_id = get_current_span()
        if self._sync:
            return sync_wrapper
        else:
            return async_wrapper
Example no. 9
0
class SNMP(object):
    name = "snmp"

    class TimeOutError(NOCError):
        default_code = ERR_SNMP_TIMEOUT
        default_msg = "SNMP Timeout"

    class FatalTimeoutError(NOCError):
        default_code = ERR_SNMP_FATAL_TIMEOUT
        default_msg = "Fatal SNMP Timeout"

    SNMPError = SNMPError

    def __init__(self, script):
        self._script = weakref.ref(script)
        self.ioloop = None
        self.result = None
        self.logger = PrefixLoggerAdapter(script.logger, self.name)
        self.timeouts_limit = 0
        self.timeouts = 0
        self.socket = None

    @property
    def script(self):
        return self._script()

    def set_timeout_limits(self, n):
        """
        Set sequential timeouts limit
        :param n: Number of consecutive timeouts allowed before FatalTimeoutError is raised
        :return:
        """
        self.timeouts_limit = n
        self.timeouts = n

    def close(self):
        if self.socket:
            self.logger.debug("Closing UDP socket")
            self.socket.close()
            self.socket = None
        if self.ioloop:
            self.logger.debug("Closing IOLoop")
            self.ioloop.close(all_fds=True)
            self.ioloop = None

    def get_ioloop(self):
        if not self.ioloop:
            self.logger.debug("Creating IOLoop")
            self.ioloop = tornado.ioloop.IOLoop()
        return self.ioloop

    def get_socket(self):
        if not self.socket:
            self.logger.debug("Create UDP socket")
            self.socket = UDPSocket(ioloop=self.get_ioloop(),
                                    tos=self.script.tos)
        return self.socket

    def _get_snmp_version(self, version=None):
        if version is not None:
            return version
        if self.script.has_snmp_v2c():
            return SNMP_v2c
        elif self.script.has_snmp_v3():
            return SNMP_v3
        elif self.script.has_snmp_v1():
            return SNMP_v1
        return SNMP_v2c

    def get(self, oids, cached=False, version=None, raw_varbinds=False):
        """
        Perform SNMP GET request
        :param oids: string or list of OIDs
        :param cached: True if get results can be cached during session
        :param raw_varbinds: Return values in BER encoding
        :returns: either a result scalar or a dict of name -> value
        """
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_get(
                    address=self.script.credentials["address"],
                    oids=oids,
                    community=str(self.script.credentials["snmp_ro"]),
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket(),
                    version=version,
                    raw_varbinds=raw_varbinds,
                )
                self.timeouts = self.timeouts_limit
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    if self.timeouts_limit:
                        self.timeouts -= 1
                        if not self.timeouts:
                            raise self.FatalTimeoutError()
                    raise self.TimeOutError()
                else:
                    raise

        version = self._get_snmp_version(version)
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r

    def set(self, *args):
        """
        Perform SNMP SET request
        :param args: Either a list of (oid, value) varbinds or a single oid, value pair
        :returns: either a result scalar or a dict of name -> value
        """
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_set(
                    address=self.script.credentials["address"],
                    varbinds=varbinds,
                    community=str(self.script.credentials["snmp_rw"]),
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket(),
                )
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    raise self.TimeOutError()
                else:
                    raise

        if len(args) == 1:
            varbinds = args
        elif len(args) == 2:
            varbinds = [(args[0], args[1])]
        else:
            raise ValueError("Invalid varbinds")
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r

    def count(self, oid, filter=None, version=None):
        """
        Iterate MIB subtree and count matching instances
        :param oid: OID
        :param filter: Callable accepting oid and value and returning boolean
        """
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_count(
                    address=self.script.credentials["address"],
                    oid=oid,
                    community=str(self.script.credentials["snmp_ro"]),
                    bulk=self.script.has_snmp_bulk(),
                    filter=filter,
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket(),
                    version=version,
                )
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    raise self.TimeOutError()
                else:
                    raise

        version = self._get_snmp_version(version)
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r

    def getnext(
        self,
        oid,
        community_suffix=None,
        filter=None,
        cached=False,
        only_first=False,
        bulk=None,
        max_repetitions=None,
        version=None,
        max_retries=0,
        timeout=10,
        raw_varbinds=False,
    ):
        @tornado.gen.coroutine
        def run():
            try:
                self.result = yield snmp_getnext(
                    address=self.script.credentials["address"],
                    oid=oid,
                    community=str(self.script.credentials["snmp_ro"]),
                    bulk=self.script.has_snmp_bulk() if bulk is None else bulk,
                    max_repetitions=max_repetitions,
                    filter=filter,
                    only_first=only_first,
                    tos=self.script.tos,
                    ioloop=self.get_ioloop(),
                    udp_socket=self.get_socket(),
                    version=version,
                    max_retries=max_retries,
                    timeout=timeout,
                    raw_varbinds=raw_varbinds,
                )
            except SNMPError as e:
                if e.code == TIMED_OUT:
                    raise self.TimeOutError()
                else:
                    raise

        version = self._get_snmp_version(version)
        self.get_ioloop().run_sync(run)
        r, self.result = self.result, None
        return r

    def get_table(self, oid, community_suffix=None, cached=False):
        """
        GETNEXT wrapper. Returns a hash of <index> -> <value>
        """
        r = {}
        for o, v in self.getnext(oid,
                                 community_suffix=community_suffix,
                                 cached=cached):
            r[int(o.split(".")[-1])] = v
        return r
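
    # A minimal sketch (the OID is IF-MIB ifName; returned values are
    # hypothetical): the last sub-identifier of each OID becomes the integer key.
    #
    #     self.get_table("1.3.6.1.2.1.31.1.1.1.1")
    #     # -> {1: "Gi0/1", 2: "Gi0/2", ...}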

    def join_tables(self, oid1, oid2, community_suffix=None, cached=False):
        """
        Generator yielding rows of two SNMP tables joined by index
        """
        t1 = self.get_table(oid1,
                            community_suffix=community_suffix,
                            cached=cached)
        t2 = self.get_table(oid2,
                            community_suffix=community_suffix,
                            cached=cached)
        for k1, v1 in six.iteritems(t1):
            try:
                yield v1, t2[k1]
            except KeyError:
                pass

    def get_tables(
        self,
        oids,
        community_suffix=None,
        bulk=False,
        min_index=None,
        max_index=None,
        cached=False,
        max_retries=0,
    ):
        """
        Query a list of SNMP tables referenced by oids and yield
        tuples of (key, value1, ..., valueN)

        :param oids: List of OIDs
        :param community_suffix: Optional suffix to be added to community
        :param bulk: Use BULKGETNEXT if true
        :param min_index:
        :param max_index:
        :param cached:
        :param max_retries:
        :return:
        """
        def gen_table(oid):
            line = len(oid) + 1
            for o, v in self.getnext(
                    oid,
                    community_suffix=community_suffix,
                    cached=cached,
                    bulk=bulk,
                    max_retries=max_retries,
            ):
                yield tuple([int(x) for x in o[line:].split(".")]), v

        # Retrieve tables
        tables = [dict(gen_table(oid)) for oid in oids]
        # Generate index
        index = set()
        for t in tables:
            index.update(t)
        # Yield result
        for i in sorted(index):
            yield [".".join([str(x) for x in i])] + [t.get(i) for t in tables]
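
    # A minimal sketch (OIDs are IF-MIB ifName and ifOperStatus): rows are
    # joined on the common index suffix; missing cells come back as None.
    #
    #     for index, name, oper_status in self.get_tables(
    #             ["1.3.6.1.2.1.31.1.1.1.1", "1.3.6.1.2.1.2.2.1.8"]):
    #         ...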

    def join(self, oids, community_suffix=None, cached=False, join="left"):
        """
        Query list of tables, merge by oid index
        Tables are records of:
        * <oid>.<index> = value

        join may be one of:
        * left
        * inner
        * outer

        Yield records of (<index>, <value1>, ..., <valueN>)
        """
        tables = [
            self.get_table(o, community_suffix=community_suffix, cached=cached)
            for o in oids
        ]
        if join == "left":
            lt = tables[1:]
            for k in sorted(tables[0]):
                yield tuple([k, tables[0][k]] + [t.get(k) for t in lt])
        elif join == "inner":
            keys = set(tables[0])
            for lt in tables[1:]:
                keys &= set(lt)
            for k in sorted(keys):
                yield tuple([k] + [t.get(k) for t in tables])
        elif join == "outer":
            keys = set(tables[0])
            for lt in tables[1:]:
                keys |= set(lt)
            for k in sorted(keys):
                yield tuple([k] + [t.get(k) for t in tables])
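
    # Behaviour sketch: "left" keeps every index of the first table, "inner"
    # only the indexes shared by all tables, "outer" their union.
    #
    #     for index, name, oper_status in self.join(
    #             ["1.3.6.1.2.1.31.1.1.1.1", "1.3.6.1.2.1.2.2.1.8"],
    #             join="inner"):
    #         ...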

    def get_chunked(self, oids, chunk_size=20, timeout_limits=3):
        """
        Fetch list of oids splitting to several operations when necessary

        :param oids: List of oids
        :param chunk_size: Maximal GET chunk size
        :param timeout_limits: SNMP timeout limits
        :return: dict of oid -> value for all retrieved values
        """
        results = {}
        self.set_timeout_limits(timeout_limits)
        while oids:
            chunk, oids = oids[:chunk_size], oids[chunk_size:]
            chunk = dict((x, x) for x in chunk)
            try:
                results.update(self.get(chunk))
            except self.TimeOutError as e:
                self.logger.error("Failed to get SNMP OIDs %s: %s", oids, e)
            except self.FatalTimeoutError:
                self.logger.error("Fatal timeout error on: %s", oids)
                break
            except self.SNMPError as e:
                self.logger.error("SNMP error code %s", e.code)
        return results
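
    # A minimal sketch (OIDs are SNMPv2-MIB sysUpTime and sysName): OIDs that
    # time out are simply missing from the returned dict.
    #
    #     r = self.get_chunked(["1.3.6.1.2.1.1.3.0", "1.3.6.1.2.1.1.5.0"],
    #                          chunk_size=10)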
Example no. 11
0
File: base.py Project: 0pt1on/noc
class HTTP(object):
    HTTPError = HTTPError

    def __init__(self, script):
        self.script = script
        if script:  # For testing purposes
            self.logger = PrefixLoggerAdapter(script.logger, "http")
        self.headers = {}
        self.cookies = None
        self.session_started = False
        self.request_id = 1
        self.session_id = None
        self.request_middleware = None
        if self.script:  # For testing purposes
            self.setup_middleware()

    def get_url(self, path):
        address = self.script.credentials["address"]
        port = self.script.credentials.get("http_port")
        if port:
            address += ":%s" % port
        proto = self.script.credentials.get("http_protocol", "http")
        return "%s://%s%s" % (proto, address, path)

    def get(self,
            path,
            headers=None,
            cached=False,
            json=False,
            eof_mark=None,
            use_basic=False):
        """
        Perform HTTP GET request
        :param path: URI
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Wait for eof_mark in the stream to end the session (some devices may return zero length)
        :param use_basic: Use basic authentication
        """
        self.ensure_session()
        self.request_id += 1
        self.logger.debug("GET %s", path)
        if cached:
            cache_key = "get_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        user, password = None, None
        if use_basic:
            user = self.script.credentials.get("user")
            password = self.script.credentials.get("password")
        # Apply GET middleware
        url = self.get_url(path)
        hdr = self._get_effective_headers(headers)
        if self.request_middleware:
            for mw in self.request_middleware:
                url, _, hdr = mw.process_get(url, "", hdr)
        code, headers, result = fetch_sync(
            url,
            headers=hdr,
            request_timeout=60,
            follow_redirects=True,
            allow_proxy=False,
            validate_cert=False,
            eof_mark=eof_mark,
            user=user,
            password=password,
        )
        if not 200 <= code <= 299:
            raise HTTPError(msg="HTTP Error (%s)" % result[:256], code=code)
        self._process_cookies(headers)
        if json:
            try:
                result = ujson.loads(result)
            except ValueError as e:
                raise HTTPError("Failed to decode JSON: %s", e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def post(self,
             path,
             data,
             headers=None,
             cached=False,
             json=False,
             eof_mark=None,
             use_basic=False):
        """
        Perform HTTP POST request
        :param path: URI
        :param data: Request body
        :param headers: Dict of additional headers
        :param cached: Cache result
        :param json: Decode json if set to True
        :param eof_mark: Wait for eof_mark in the stream to end the session (some devices may return zero length)
        :param use_basic: Use basic authentication
        """
        self.ensure_session()
        self.request_id += 1
        self.logger.debug("POST %s %s", path, data)
        if cached:
            cache_key = "post_%s" % path
            r = self.script.root.http_cache.get(cache_key)
            if r is not None:
                self.logger.debug("Use cached result")
                return r
        user, password = None, None
        if use_basic:
            user = self.script.credentials.get("user")
            password = self.script.credentials.get("password")
        # Apply POST middleware
        url = self.get_url(path)
        hdr = self._get_effective_headers(headers)
        if self.request_middleware:
            for mw in self.request_middleware:
                url, data, hdr = mw.process_post(url, data, hdr)
        code, headers, result = fetch_sync(
            url,
            method="POST",
            body=data,
            headers=hdr,
            request_timeout=60,
            follow_redirects=True,
            allow_proxy=False,
            validate_cert=False,
            eof_mark=eof_mark,
            user=user,
            password=password,
        )
        if not 200 <= code <= 299:
            raise HTTPError(msg="HTTP Error (%s)" % result[:256], code=code)
        self._process_cookies(headers)
        if json:
            try:
                result = ujson.loads(result)
            except ValueError as e:
                raise HTTPError(msg="Failed to decode JSON: %s" % e)
        self.logger.debug("Result: %r", result)
        if cached:
            self.script.root.http_cache[cache_key] = result
        return result

    def close(self):
        if self.session_started:
            self.shutdown_session()

    def _process_cookies(self, headers):
        """
        Process and store cookies from response headers
        :param headers:
        :return:
        """
        cdata = headers.get("Set-Cookie")
        if not cdata:
            return
        if not self.cookies:
            self.cookies = SimpleCookie()
        self.cookies.load(cdata)

    def get_cookie(self, name):
        """
        Get cookie by name
        :param name: Cookie name
        :return: Morsel object or None
        """
        if not self.cookies:
            return None
        return self.cookies.get(name)

    def _get_effective_headers(self, headers):
        """
        Append session headers when necessary. Apply effective cookies
        :param headers:
        :return:
        """
        if self.headers:
            if headers:
                headers = headers.copy()
            else:
                headers = {}
            headers.update(self.headers)
        elif not headers and self.cookies:
            headers = {}
        if self.cookies:
            headers["Cookie"] = self.cookies.output(header="").lstrip()
        return headers

    def set_header(self, name, value):
        """
        Set HTTP header to be set with all following requests
        :param name:
        :param value:
        :return:
        """
        self.logger.debug("Set header: %s = %s", name, value)
        self.headers[name] = str(value)

    def set_session_id(self, session_id):
        """
        Set session_id to be reused by middleware
        :param session_id:
        :return: None
        """
        self.session_id = session_id

    def ensure_session(self):
        if not self.session_started:
            self.session_started = True
            self.setup_session()

    def setup_session(self):
        if self.script.profile.setup_http_session:
            self.logger.debug("Setup http session")
            self.script.profile.setup_http_session(self.script)

    def shutdown_session(self):
        if self.script.profile.shutdown_http_session:
            self.logger.debug("Shutdown http session")
            self.script.profile.shutdown_http_session(self.script)

    def setup_middleware(self):
        mw_list = self.script.profile.get_http_request_middleware(self.script)
        if not mw_list:
            return
        self.request_middleware = []
        for mw_cfg in mw_list:
            if isinstance(mw_cfg, tuple):
                name, cfg = mw_cfg
            else:
                name, cfg = mw_cfg, {}
            if "." in name:
                # Handler
                mw_cls = get_handler(name)
                assert mw_cls
                assert issubclass(mw_cls, BaseMiddleware)
            else:
                # Middleware name
                mw_cls = loader.get_class(name)
            self.request_middleware += [mw_cls(self, **cfg)]
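A small runnable sketch of the cookie and header merging performed by _process_cookies() and _get_effective_headers() above; the token, cookie and header values are made up, and the import shim covers both Python 2 and 3.

try:
    from http.cookies import SimpleCookie  # Python 3
except ImportError:
    from Cookie import SimpleCookie  # Python 2

cookies = SimpleCookie()
cookies.load("SESSIONID=abc123; Path=/")      # what a Set-Cookie header would deliver
session_headers = {"X-Auth-Token": "t0ken"}   # what set_header() would accumulate

headers = {"Accept": "application/json"}      # per-request headers
headers = headers.copy()
headers.update(session_headers)               # session headers take precedence
headers["Cookie"] = cookies.output(header="").lstrip()
print(headers)

From a script the wrapper itself is driven through calls such as self.http.get(path, json=True) or self.http.post(path, data), with caching, cookies and middleware applied transparently.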
Esempio n. 12
0
 def __init__(self):
     self.logger = PrefixLoggerAdapter(logger, self.name)
Esempio n. 13
0
 def __init__(self, script):
     self.script = script
     self.logger = PrefixLoggerAdapter(script.logger, "http")
Esempio n. 14
0
class CLI(object):
    name = "cli"
    default_port = None
    iostream_class = None
    BUFFER_SIZE = config.activator.buffer_size
    MATCH_TAIL = 256
    # Buffer to check missed ECMA control characters
    MATCH_MISSED_CONTROL_TAIL = 8
    # Retries on immediate disconnect
    CONNECT_RETRIES = config.activator.connect_retries
    # Timeout after immediate disconnect
    CONNECT_TIMEOUT = config.activator.connect_timeout
    # compiled capabilities
    HAS_TCP_KEEPALIVE = hasattr(socket, "SO_KEEPALIVE")
    HAS_TCP_KEEPIDLE = hasattr(socket, "TCP_KEEPIDLE")
    HAS_TCP_KEEPINTVL = hasattr(socket, "TCP_KEEPINTVL")
    HAS_TCP_KEEPCNT = hasattr(socket, "TCP_KEEPCNT")
    HAS_TCP_NODELAY = hasattr(socket, "TCP_NODELAY")
    # Time until sending first keepalive probe
    KEEP_IDLE = 10
    # Keepalive packets interval
    KEEP_INTVL = 10
    # Terminate connection after N keepalive failures
    KEEP_CNT = 3
    SYNTAX_ERROR_CODE = "+@@@NOC:SYNTAXERROR@@@+"

    class InvalidPagerPattern(Exception):
        pass

    def __init__(self, script, tos=None):
        self.script = script
        self.profile = script.profile
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        self.iostream = None
        self.motd = ""
        self.ioloop = None
        self.command = None
        self.prompt_stack = []
        self.patterns = self.profile.patterns.copy()
        self.buffer = ""
        self.is_started = False
        self.result = None
        self.error = None
        self.pattern_table = None
        self.collected_data = []
        self.tos = tos
        self.current_timeout = None
        self.is_closed = False
        self.close_timeout = None
        self.setup_complete = False
        self.to_raise_privileges = script.credentials.get(
            "raise_privileges", True)
        self.state = "start"
        # State retries
        self.super_password_retries = self.profile.cli_retries_super_password

    def close(self):
        self.script.close_current_session()
        self.close_iostream()
        if self.ioloop:
            self.logger.debug("Closing IOLoop")
            self.ioloop.close(all_fds=True)
            self.ioloop = None
        self.is_closed = True

    def close_iostream(self):
        if self.iostream:
            self.logger.debug("Closing IOStream")
            self.iostream.close()
            self.iostream = None

    def set_state(self, state):
        self.logger.debug("Changing state to <%s>", state)
        self.state = state

    def deferred_close(self, session_timeout):
        if self.is_closed or not self.iostream:
            return
        self.logger.debug("Setting close timeout to %ss", session_timeout)
        # Cannot call call_later directly due to
        # thread-safety problems
        # See tornado issue #1773
        tornado.ioloop.IOLoop.instance().add_callback(self._set_close_timeout,
                                                      session_timeout)

    def _set_close_timeout(self, session_timeout):
        """
        Wrapper to deal with IOLoop.add_timeout thread safety problem
        :param session_timeout:
        :return:
        """
        self.close_timeout = tornado.ioloop.IOLoop.instance().call_later(
            session_timeout, self.close)

    def create_iostream(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if self.tos:
            s.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, self.tos)
        if self.HAS_TCP_NODELAY:
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if self.HAS_TCP_KEEPALIVE:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            if self.HAS_TCP_KEEPIDLE:
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
                             self.KEEP_IDLE)
            if self.HAS_TCP_KEEPINTVL:
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
                             self.KEEP_INTVL)
            if self.HAS_TCP_KEEPCNT:
                s.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, self.KEEP_CNT)
        return self.iostream_class(s, self)

    def set_timeout(self, timeout):
        if timeout:
            self.logger.debug("Setting timeout: %ss", timeout)
            self.current_timeout = datetime.timedelta(seconds=timeout)
        else:
            if self.current_timeout:
                self.logger.debug("Resetting timeouts")
            self.current_timeout = None

    def run_sync(self, func, *args, **kwargs):
        """
        Simplified implementation of IOLoop.run_sync
        to distinguish real TimeoutErrors from incomplete futures
        :param func:
        :param args:
        :param kwargs:
        :return:
        """
        future_cell = [None]

        def run():
            try:
                result = func(*args, **kwargs)
                if result is not None:
                    result = tornado.gen.convert_yielded(result)
                future_cell[0] = result
            except Exception:
                future_cell[0] = tornado.concurrent.TracebackFuture()
                future_cell[0].set_exc_info(sys.exc_info())
            self.ioloop.add_future(future_cell[0],
                                   lambda future: self.ioloop.stop())

        self.ioloop.add_callback(run)
        self.ioloop.start()
        if not future_cell[0].done():
            self.logger.info("Incomplete feature left. Restarting IOStream")
            self.close_iostream()
            # Retain cryptic message as is,
            # Mark feature as done
            future_cell[0].set_exception(
                tornado.gen.TimeoutError(
                    "Operation timed out after %s seconds" % None))
        return future_cell[0].result()

    def execute(self,
                cmd,
                obj_parser=None,
                cmd_next=None,
                cmd_stop=None,
                ignore_errors=False):
        if self.close_timeout:
            self.logger.debug("Removing close timeout")
            self.ioloop.remove_timeout(self.close_timeout)
            self.close_timeout = None
        self.buffer = ""
        self.command = cmd
        self.error = None
        self.ignore_errors = ignore_errors
        if not self.ioloop:
            self.logger.debug("Creating IOLoop")
            self.ioloop = tornado.ioloop.IOLoop()
        if obj_parser:
            parser = functools.partial(self.parse_object_stream, obj_parser,
                                       cmd_next, cmd_stop)
        else:
            parser = self.read_until_prompt
        with Span(server=self.script.credentials.get("address"),
                  service=self.name,
                  in_label=cmd) as s:
            self.run_sync(self.submit, parser)
            if self.error:
                if s:
                    s.error_text = str(self.error)
                raise self.error
            else:
                return self.result

    @tornado.gen.coroutine
    def submit(self, parser=None):
        # Create iostream and connect, when necessary
        if not self.iostream:
            self.iostream = self.create_iostream()
            address = (
                self.script.credentials.get("address"),
                self.script.credentials.get("cli_port", self.default_port),
            )
            self.logger.debug("Connecting %s", address)
            try:
                yield self.iostream.connect(address)
            except tornado.iostream.StreamClosedError:
                self.logger.debug("Connection refused")
                self.error = CLIConnectionRefused("Connection refused")
                raise tornado.gen.Return(None)
            self.logger.debug("Connected")
            yield self.iostream.startup()
        # Perform all necessary login procedures
        if not self.is_started:
            yield self.on_start()
            self.motd = yield self.read_until_prompt()
            self.script.set_motd(self.motd)
            self.is_started = True
        # Send command
        # @todo: encode to object's encoding
        if self.profile.batch_send_multiline or self.profile.command_submit not in self.command:
            yield self.send(self.command)
        else:
            # Send multiline commands line-by-line
            for cmd in self.command.split(self.profile.command_submit):
                # Send line
                yield self.send(cmd + self.profile.command_submit)
                # @todo: Await response
        parser = parser or self.read_until_prompt
        self.result = yield parser()
        self.logger.debug("Command: %s\n%s", self.command.strip(), self.result)
        if (self.profile.rx_pattern_syntax_error and not self.ignore_errors
                and parser == self.read_until_prompt
                and (self.profile.rx_pattern_syntax_error.search(self.result)
                     or self.result == self.SYNTAX_ERROR_CODE)):
            error_text = self.result
            if self.profile.send_on_syntax_error and self.name != "beef_cli":
                yield self.on_error_sequence(self.profile.send_on_syntax_error,
                                             self.command, error_text)
            self.error = self.script.CLISyntaxError(error_text)
            self.result = None
        raise tornado.gen.Return(self.result)

    def cleaned_input(self, s):
        """
        Clean up received input and wipe out control sequences
        and rogue chars
        """
        # Wipe out rogue chars
        if self.profile.rogue_chars:
            for rc in self.profile.rogue_chars:
                try:
                    s = rc.sub("", s)  # rc is compiled regular expression
                except AttributeError:
                    s = s.replace(rc, "")  # rc is a string
        # Clean control sequences
        return self.profile.cleaned_input(s)

    @tornado.gen.coroutine
    def send(self, cmd):
        # @todo: Apply encoding
        cmd = str(cmd)
        self.logger.debug("Send: %r", cmd)
        yield self.iostream.write(cmd)

    @tornado.gen.coroutine
    def read_until_prompt(self):
        connect_retries = self.CONNECT_RETRIES
        while True:
            try:
                f = self.iostream.read_bytes(self.BUFFER_SIZE, partial=True)
                if self.current_timeout:
                    r = yield tornado.gen.with_timeout(self.current_timeout, f)
                else:
                    r = yield f
                if r == self.SYNTAX_ERROR_CODE:
                    raise tornado.gen.Return(self.SYNTAX_ERROR_CODE)
                if self.script.to_track:
                    self.script.push_cli_tracking(r, self.state)
            except tornado.iostream.StreamClosedError:
                # Check if remote end closes connection just
                # after connection established
                if not self.is_started and connect_retries:
                    self.logger.info(
                        "Connection reset. %d retries left. Waiting %d seconds",
                        connect_retries,
                        self.CONNECT_TIMEOUT,
                    )
                    while connect_retries:
                        yield tornado.gen.sleep(self.CONNECT_TIMEOUT)
                        connect_retries -= 1
                        self.iostream = self.create_iostream()
                        address = (
                            self.script.credentials.get("address"),
                            self.script.credentials.get(
                                "cli_port", self.default_port),
                        )
                        self.logger.debug("Connecting %s", address)
                        try:
                            yield self.iostream.connect(address)
                            yield self.iostream.startup()
                            break
                        except tornado.iostream.StreamClosedError:
                            if not connect_retries:
                                raise tornado.iostream.StreamClosedError()
                    continue
                else:
                    raise tornado.iostream.StreamClosedError()
            except tornado.gen.TimeoutError:
                self.logger.info("Timeout error")
                # IOStream must be closed to prevent hanging read callbacks
                self.close_iostream()
                raise tornado.gen.TimeoutError("Timeout")
            self.logger.debug("Received: %r", r)
            # Clean input
            if self.buffer.find("\x1b", -self.MATCH_MISSED_CONTROL_TAIL) != -1:
                self.buffer = self.cleaned_input(self.buffer + r)
            else:
                self.buffer += self.cleaned_input(r)
            # Try to find matched pattern
            offset = max(0, len(self.buffer) - self.MATCH_TAIL)
            for rx, handler in six.iteritems(self.pattern_table):
                match = rx.search(self.buffer, offset)
                if match:
                    self.logger.debug("Match: %s", rx.pattern)
                    matched = self.buffer[:match.start()]
                    self.buffer = self.buffer[match.end():]
                    if isinstance(handler, tuple):
                        r = yield handler[0](matched, match, *handler[1:])
                    else:
                        r = yield handler(matched, match)
                    if r is not None:
                        raise tornado.gen.Return(r)
                    else:
                        break  # This state is processed

    @tornado.gen.coroutine
    def parse_object_stream(self, parser=None, cmd_next=None, cmd_stop=None):
        """
        :param parser: callable accepting buffer and returning
                       (key, data, rest) or None.
                       key - string with object distinguisher
                       data - dict containing attributes
                       rest -- unparsed rest of string
        :param cmd_next: Sequence to go to the next page
        :param cmd_stop: Sequence to stop
        :return:
        """
        self.logger.debug("Parsing object stream")
        objects = []
        seen = set()
        buffer = ""
        repeats = 0
        r_key = None
        stop_sent = False
        done = False
        while not done:
            r = yield self.iostream.read_bytes(self.BUFFER_SIZE, partial=True)
            if self.script.to_track:
                self.script.push_cli_tracking(r, self.state)
            self.logger.debug("Received: %r", r)
            buffer = self.cleaned_input(buffer + r)
            # Check for syntax error
            if (self.profile.rx_pattern_syntax_error and not self.ignore_errors
                    and self.profile.rx_pattern_syntax_error.search(buffer)):
                # Match against the local buffer accumulated above;
                # self.buffer is not updated by this parser
                error_text = buffer
                if self.profile.send_on_syntax_error:
                    yield self.on_error_sequence(
                        self.profile.send_on_syntax_error, self.command,
                        error_text)
                self.error = self.script.CLISyntaxError(error_text)
                break
            # Then check for operation error
            if (self.profile.rx_pattern_operation_error
                    and self.profile.rx_pattern_operation_error.search(buffer)):
                self.error = self.script.CLIOperationError(buffer)
                break
            # Parse all possible objects
            while buffer:
                pr = parser(buffer)
                if not pr:
                    break  # No new objects
                key, obj, buffer = pr
                if key not in seen:
                    seen.add(key)
                    objects += [obj]
                    repeats = 0
                    r_key = None
                elif r_key:
                    if r_key == key:
                        repeats += 1
                        if repeats >= 3 and cmd_stop and not stop_sent:
                            # Stop loop at final page
                            # After 3 repeats
                            self.logger.debug("Stopping stream. Sending %r" %
                                              cmd_stop)
                            self.send(cmd_stop)
                            stop_sent = True
                else:
                    r_key = key
                    if cmd_next:
                        self.logger.debug("Next screen. Sending %r" % cmd_next)
                        self.send(cmd_next)
            # Check for prompt
            for rx, handler in six.iteritems(self.pattern_table):
                offset = max(0, len(buffer) - self.MATCH_TAIL)
                match = rx.search(buffer, offset)
                if match:
                    self.logger.debug("Match: %s", rx.pattern)
                    matched = buffer[:match.start()]
                    buffer = buffer[match.end():]
                    r = handler(matched, match)
                    if r is not None:
                        self.logger.debug("Prompt matched")
                        done = True
                        break
        raise tornado.gen.Return(objects)

    def send_pager_reply(self, data, match):
        """
        Send proper pager reply
        """
        pg = match.group(0)
        for p, c in self.patterns["more_patterns_commands"]:
            if p.search(pg):
                self.collected_data += [data]
                self.send(c)
                return
        raise self.InvalidPagerPattern(pg)

    def expect(self, patterns, timeout=None):
        """
        Send command if not none and set reply patterns
        """
        self.pattern_table = {}
        for pattern_name in patterns:
            rx = self.patterns.get(pattern_name)
            if not rx:
                continue
            self.pattern_table[rx] = patterns[pattern_name]
        self.set_timeout(timeout)

    @tornado.gen.coroutine
    def on_start(self, data=None, match=None):
        self.set_state("start")
        if self.profile.setup_sequence and not self.setup_complete:
            self.expect({"setup": self.on_setup_sequence},
                        self.profile.cli_timeout_setup)
        else:
            self.expect(
                {
                    "username": self.on_username,
                    "password": self.on_password,
                    "unprivileged_prompt": self.on_unprivileged_prompt,
                    "prompt": self.on_prompt,
                    "pager": self.send_pager_reply,
                },
                self.profile.cli_timeout_start,
            )

    @tornado.gen.coroutine
    def on_username(self, data, match):
        self.set_state("username")
        self.send((self.script.credentials.get("user", "") or "") +
                  (self.profile.username_submit or "\n"))
        self.expect(
            {
                "username": (self.on_failure, CLIAuthFailed),
                "password": self.on_password,
                "unprivileged_prompt": self.on_unprivileged_prompt,
                "prompt": self.on_prompt,
            },
            self.profile.cli_timeout_user,
        )

    @tornado.gen.coroutine
    def on_password(self, data, match):
        self.set_state("password")
        self.send((self.script.credentials.get("password", "") or "") +
                  (self.profile.password_submit or "\n"))
        self.expect(
            {
                "username": (self.on_failure, CLIAuthFailed),
                "password": (self.on_failure, CLIAuthFailed),
                "unprivileged_prompt": self.on_unprivileged_prompt,
                "super_password": self.on_super_password,
                "prompt": self.on_prompt,
                "pager": self.send_pager_reply,
            },
            self.profile.cli_timeout_password,
        )

    @tornado.gen.coroutine
    def on_unprivileged_prompt(self, data, match):
        self.set_state("unprivileged_prompt")
        if self.to_raise_privileges:
            # Start privilege raising sequence
            if not self.profile.command_super:
                self.on_failure(data, match, CLINoSuperCommand)
            self.send(self.profile.command_super +
                      (self.profile.command_submit or "\n"))
            # Do not remove `pager` section
            # It fixes this situation on Huawei MA5300:
            # xxx>enable
            # { <cr>|level-value<U><1,15> }:
            # xxx#
            self.expect(
                {
                    "username": self.on_super_username,
                    "password": self.on_super_password,
                    "prompt": self.on_prompt,
                    "pager": self.send_pager_reply,
                },
                self.profile.cli_timeout_super,
            )
        else:
            # Do not raise privileges
            # Use unprivileged prompt as primary prompt
            self.patterns["prompt"] = self.patterns["unprivileged_prompt"]
            return self.on_prompt(data, match)

    @tornado.gen.coroutine
    def on_failure(self, data, match, error_cls=None):
        self.set_state("failure")
        error_cls = error_cls or CLIError
        raise error_cls(self.buffer or data or None)

    @tornado.gen.coroutine
    def on_prompt(self, data, match):
        self.set_state("prompt")
        if not self.is_started:
            self.resolve_pattern_prompt(match)
        d = "".join(self.collected_data + [data])
        self.collected_data = []
        self.expect({"prompt": self.on_prompt, "pager": self.send_pager_reply})
        return d

    @tornado.gen.coroutine
    def on_super_username(self, data, match):
        self.set_state("super_username")
        self.send((self.script.credentials.get("user", "") or "") +
                  (self.profile.username_submit or "\n"))
        self.expect(
            {
                "username": (self.on_failure, CLILowPrivileges),
                "password": self.on_super_password,
                "unprivileged_prompt": self.on_unprivileged_prompt,
                "prompt": self.on_prompt,
                "pager": self.send_pager_reply,
            },
            self.profile.cli_timeout_user,
        )

    @tornado.gen.coroutine
    def on_super_password(self, data, match):
        self.set_state("super_password")
        self.send((self.script.credentials.get("super_password", "") or "") +
                  (self.profile.username_submit or "\n"))
        if self.super_password_retries > 1:
            unprivileged_handler = self.on_unprivileged_prompt
            self.super_password_retries -= 1
        else:
            unprivileged_handler = (self.on_failure, CLILowPrivileges)
        self.expect(
            {
                "prompt": self.on_prompt,
                "password": (self.on_failure, CLILowPrivileges),
                "super_password": (self.on_failure, CLILowPrivileges),
                "pager": self.send_pager_reply,
                "unprivileged_prompt": unprivileged_handler,
            },
            self.profile.cli_timeout_password,
        )

    @tornado.gen.coroutine
    def on_setup_sequence(self, data, match):
        self.set_state("setup")
        self.logger.debug("Performing setup sequence: %s",
                          self.profile.setup_sequence)
        lseq = len(self.profile.setup_sequence)
        for i, c in enumerate(self.profile.setup_sequence):
            if isinstance(c, six.integer_types) or isinstance(c, float):
                yield tornado.gen.sleep(c)
                continue
            cmd = c % self.script.credentials
            yield self.send(cmd)
            # Waiting for response and drop it
            if i < lseq - 1:
                resp = yield tornado.gen.with_timeout(
                    self.ioloop.time() + 30,
                    future=self.iostream.read_bytes(4096, partial=True),
                    io_loop=self.ioloop,
                )
                if self.script.to_track:
                    self.script.push_cli_tracking(resp, self.state)
                self.logger.debug("Receiving: %r", resp)
        self.logger.debug("Setup sequence complete")
        self.setup_complete = True
        yield self.on_start(data, match)

    def resolve_pattern_prompt(self, match):
        """
        Resolve adaptive pattern prompt
        """
        old_pattern_prompt = self.patterns["prompt"].pattern
        pattern_prompt = old_pattern_prompt
        sl = self.profile.can_strip_hostname_to
        for k, v in six.iteritems(match.groupdict()):
            if v:
                if k == "hostname" and sl and len(v) > sl:
                    ss = list(reversed(v[sl:]))
                    v = re.escape(v[:sl]) + reduce(
                        lambda x, y: "(?:%s%s)?" % (re.escape(y), x),
                        ss[1:],
                        "(?:%s)?" % re.escape(ss[0]),
                    )
                else:
                    v = re.escape(v)
                pattern_prompt = replace_re_group(pattern_prompt,
                                                  "(?P<%s>" % k, v)
                pattern_prompt = replace_re_group(pattern_prompt, "(?P=%s" % k,
                                                  v)
            else:
                self.logger.error("Invalid prompt pattern")
        if old_pattern_prompt != pattern_prompt:
            self.logger.debug("Refining pattern prompt to %r", pattern_prompt)
        self.patterns["prompt"] = re.compile(pattern_prompt,
                                             re.DOTALL | re.MULTILINE)

    def push_prompt_pattern(self, pattern):
        """
        Override prompt pattern
        """
        self.logger.debug("New prompt pattern: %s", pattern)
        self.prompt_stack += [self.patterns["prompt"]]
        self.patterns["prompt"] = re.compile(pattern, re.DOTALL | re.MULTILINE)
        self.pattern_table[self.patterns["prompt"]] = self.on_prompt

    def pop_prompt_pattern(self):
        """
        Restore prompt pattern
        """
        self.logger.debug("Restore prompt pattern")
        pattern = self.prompt_stack.pop(-1)
        self.patterns["prompt"] = pattern
        self.pattern_table[self.patterns["prompt"]] = self.on_prompt

    def get_motd(self):
        """
        Return collected message of the day
        """
        return self.motd

    def set_script(self, script):
        self.script = script
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        if self.close_timeout:
            tornado.ioloop.IOLoop.instance().remove_timeout(self.close_timeout)
            self.close_timeout = None
        if self.motd:
            self.script.set_motd(self.motd)

    def setup_session(self):
        if self.profile.setup_session:
            self.logger.debug("Setup session")
            self.profile.setup_session(self.script)

    def shutdown_session(self):
        if self.profile.shutdown_session:
            self.logger.debug("Shutdown session")
            self.profile.shutdown_session(self.script)

    @tornado.gen.coroutine
    def on_error_sequence(self, seq, command, error_text):
        """
        Process error sequence
        :param seq:
        :param command:
        :param error_text:
        :return:
        """
        if isinstance(seq, six.string_types):
            self.logger.debug("Recovering from error. Sending %r", seq)
            yield self.iostream.write(seq)
        elif callable(seq):
            if tornado.gen.is_coroutine_function(seq):
                # Yield coroutine
                yield seq(self, command, error_text)
            else:
                seq = seq(self, command, error_text)
                yield self.iostream.write(seq)
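parse_object_stream() above expects a parser callable that consumes the buffer and returns (key, data, rest) for each complete object, or None when nothing more can be extracted. A minimal sketch of such a parser for a hypothetical paged "show clients" table; the record format and values are invented for illustration.

import re

# Hypothetical paged output: one "<id> <mac> <port>" record per line
RX_CLIENT = re.compile(
    r"^(?P<id>\d+)\s+(?P<mac>[0-9A-F:]{17})\s+(?P<port>\S+)\s*\n", re.MULTILINE)


def parse_client(buffer):
    # Return (key, data, rest) for the first complete record, or None
    match = RX_CLIENT.search(buffer)
    if not match:
        return None  # No complete object in the buffer yet
    key = match.group("id")
    data = {"mac": match.group("mac"), "port": match.group("port")}
    return key, data, buffer[match.end():]


# Standalone check of the contract (the CLI class would call this repeatedly,
# sending cmd_next/cmd_stop between pages):
buf = "1   AA:BB:CC:DD:EE:01 Gi0/1\n2   AA:BB:CC:DD:EE:02 Gi0/2\nSwitch#"
objects = []
pr = parse_client(buf)
while pr:
    key, obj, buf = pr
    objects.append(obj)
    pr = parse_client(buf)
print(objects)

In real use such a callable is passed to execute() as obj_parser together with the profile's paging sequences.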
Esempio n. 15
0
class BaseExtractor(object):
    """
    Data extractor interface. Subclasses must provide
    *iter_data* method
    """
    Problem = namedtuple("Problem", ["line", "is_rej", "p_class", "message", "row"])
    name = None
    PREFIX = config.path.etl_import
    REPORT_INTERVAL = 1000
    # List of rows to be used as constant data
    data = []
    # Suppress deduplication message
    suppress_deduplication_log = False

    def __init__(self, system):
        self.system = system
        self.config = system.config
        self.logger = PrefixLoggerAdapter(
            logger, "%s][%s" % (system.name, self.name)
        )
        self.import_dir = os.path.join(self.PREFIX, system.name, self.name)
        self.fatal_problems = []
        self.quality_problems = []

    def register_quality_problem(self, line, p_class, message, row):
        self.quality_problems += [self.Problem(line=line + 1, is_rej=False, p_class=p_class, message=message, row=row)]

    def register_fatal_problem(self, line, p_class, message, row):
        self.fatal_problems += [self.Problem(line=line + 1, is_rej=True, p_class=p_class, message=message, row=row)]

    def get_new_state(self):
        if not os.path.isdir(self.import_dir):
            self.logger.info("Creating directory %s", self.import_dir)
            os.makedirs(self.import_dir)
        path = os.path.join(self.import_dir, "import.csv.gz")
        self.logger.info("Writing to %s", path)
        return gzip.GzipFile(path, "w")

    def get_problem_file(self):
        if not os.path.isdir(self.import_dir):
            self.logger.info("Creating directory %s", self.import_dir)
            os.makedirs(self.import_dir)
        path = os.path.join(self.import_dir, "import.csv.rej.gz")
        self.logger.info("Writing to %s", path)
        return gzip.GzipFile(path, "w")

    def iter_data(self):
        for row in self.data:
            yield row

    def filter(self, row):
        return True

    def clean(self, row):
        return row

    def extract(self):
        def q(s):
            if s == "" or s is None:
                return ""
            elif isinstance(s, unicode):
                return s.encode("utf-8")
            else:
                return str(s)

        # Fetch data
        self.logger.info("Extracting %s from %s",
                         self.name, self.system.name)
        t0 = time.time()
        data = []
        n = 0
        seen = set()
        for row in self.iter_data():
            if not self.filter(row):
                continue
            row = self.clean(row)
            if row[0] in seen:
                if not self.suppress_deduplication_log:
                    self.logger.error("Duplicated row truncated: %r",
                                      row)
                continue
            else:
                seen.add(row[0])
            data += [[q(x) for x in row]]
            n += 1
            if n % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", n)
        dt = time.time() - t0
        speed = n / dt
        self.logger.info(
            "%d records extracted in %.2fs (%d records/s)",
            n, dt, speed
        )
        # Sort
        data.sort()
        # Write
        f = self.get_new_state()
        writer = csv.writer(f)
        writer.writerows(data)
        f.close()
        if self.fatal_problems or self.quality_problems:
            self.logger.warning("Detect problems on extracting, fatal: %d, quality: %d",
                                len(self.fatal_problems),
                                len(self.quality_problems))
            self.logger.warning("Line num\tType\tProblem string")
            for p in self.fatal_problems:
                self.logger.warning("Fatal problem, line was rejected: %s\t%s\t%s" % (p.line, p.p_class, p.message))
            for p in self.quality_problems:
                self.logger.warning("Data quality problem in line:  %s\t%s\t%s" % (p.line, p.p_class, p.message))
            # Dump problem to file
            try:
                f = self.get_problem_file()
                writer = csv.writer(f, delimiter=";")
                for p in itertools.chain(self.quality_problems, self.fatal_problems):
                    writer.writerow(
                        [str(c).encode("utf-8") for c in p.row] + ["Fatal problem, line was rejected" if p.is_rej else
                                                                   "Data quality problem"] + [p.message.encode("utf-8")]
                    )
            except IOError as e:
                self.logger.error("Error when saved problems %s", e)
            finally:
                f.close()
        else:
            self.logger.info("No problems detected")
Esempio n. 16
0
 def __init__(self, remote_system):
     self.remote_system = remote_system
     self.name = remote_system.name
     self.config = self.remote_system.config
     self.logger = PrefixLoggerAdapter(logger, self.name)
Esempio n. 17
0
class ProfileChecker(object):
    base_logger = logging.getLogger("profilechecker")
    _rules_cache = cachetools.TTLCache(10, ttl=60)
    _re_cache = {}

    def __init__(
        self,
        address=None,
        pool=None,
        logger=None,
        snmp_community=None,
        calling_service="profilechecker",
        snmp_version=None,
    ):
        self.address = address
        self.pool = pool
        self.logger = PrefixLoggerAdapter(
            logger or self.base_logger, "%s][%s" % (self.pool or "", self.address or "")
        )
        self.result_cache = {}  # (method, param) -> result
        self.error = None
        self.snmp_community = snmp_community
        self.calling_service = calling_service
        self.snmp_version = snmp_version or [SNMP_v2c]
        self.ignoring_snmp = False
        if self.snmp_version is None:
            self.logger.error("SNMP is not supported. Ignoring")
            self.ignoring_snmp = True
        if not self.snmp_community:
            self.logger.error("No SNMP credentials. Ignoring")
            self.ignoring_snmp = True

    def find_profile(self, method, param, result):
        """
        Find profile by method
        :param method: Fingerprint getting method
        :param param: Method params
        :param result: Getting params result
        :return:
        """
        r = defaultdict(list)
        d = self.get_rules()
        for k, value in sorted(six.iteritems(d), key=lambda x: x[0]):
            for v in value:
                r[v] += value[v]
        if (method, param) not in r:
            self.logger.warning("Not find rule for method: %s %s", method, param)
            return
        for match_method, value, action, profile, rname in r[(method, param)]:
            if self.is_match(result, match_method, value):
                self.logger.info("Matched profile: %s (%s)", profile, rname)
                # @todo: process MAYBE rule
                return profile

    def get_profile(self):
        """
        Returns profile for object, or None when not known
        """
        snmp_result = ""
        http_result = ""
        for ruleset in self.iter_rules():
            for (method, param), actions in ruleset:
                try:
                    result = self.do_check(method, param)
                    if not result:
                        continue
                    if "snmp" in method:
                        snmp_result = result
                    if "http" in method:
                        http_result = result
                    for match_method, value, action, profile, rname in actions:
                        if self.is_match(result, match_method, value):
                            self.logger.info("Matched profile: %s (%s)", profile, rname)
                            # @todo: process MAYBE rule
                            return profile
                except NOCError as e:
                    self.logger.error(e.message)
                    self.error = str(e.message)
                    return None
        if snmp_result or http_result:
            self.error = "Not find profile for OID: %s or HTTP string: %s" % (
                snmp_result,
                http_result,
            )
        elif not snmp_result:
            self.error = "Cannot fetch snmp data, check device for SNMP access"
        elif not http_result:
            self.error = "Cannot fetch HTTP data, check device for HTTP access"
        self.logger.info("Cannot detect profile: %s", self.error)
        return None

    def get_error(self):
        """
        Get error message
        :return:
        """
        return self.error

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_rules_cache"), lock=lambda _: rules_lock)
    def get_profile_check_rules(cls):
        return list(ProfileCheckRule.objects.all().order_by("preference"))

    def get_rules(self):
        """
        Load ProfileCheckRules and return a dict, keyed by preference:
        {
            preference -> {
                (method, param) -> [(
                        match_method,
                        value,
                        action,
                        profile,
                        rule_name
                    ), ...]
            }
        }
        """
        self.logger.info('Compiling "Profile Check rules"')
        d = {}  # preference -> (method, param) -> [rule, ..]
        for r in self.get_profile_check_rules():
            if "snmp" in r.method and self.ignoring_snmp:
                continue
            if r.preference not in d:
                d[r.preference] = {}
            k = (r.method, r.param)
            if k not in d[r.preference]:
                d[r.preference][k] = []
            d[r.preference][k] += [(r.match_method, r.value, r.action, r.profile, r.name)]
        return d

    def iter_rules(self):
        d = self.get_rules()
        for p in sorted(d):
            yield list(six.iteritems(d[p]))

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_re_cache"))
    def get_re(cls, regexp):
        return re.compile(regexp)

    def do_check(self, method, param):
        """
        Perform check
        """
        self.logger.debug("do_check(%s, %s)", method, param)
        if (method, param) in self.result_cache:
            self.logger.debug("Using cached value")
            return self.result_cache[method, param]
        h = getattr(self, "check_%s" % method, None)
        if not h:
            self.logger.error("Invalid check method '%s'. Ignoring", method)
            return None
        result = h(param)
        self.result_cache[method, param] = result
        return result

    def check_snmp_v2c_get(self, param):
        """
        Perform SNMP v2c GET. Param is OID or symbolic name
        """
        try:
            param = mib[param]
        except KeyError:
            self.logger.error("Cannot resolve OID '%s'. Ignoring", param)
            return None
        for v in self.snmp_version:
            if v == SNMP_v1:
                r = self.snmp_v1_get(param)
            elif v == SNMP_v2c:
                r = self.snmp_v2c_get(param)
            else:
                raise NOCError(msg="Unsupported SNMP version")
            if r:
                return r

    def check_http_get(self, param):
        """
        Perform HTTP GET check. Param can be URL path or :<port>/<path>
        """
        url = "http://%s%s" % (self.address, param)
        return self.http_get(url)

    def check_https_get(self, param):
        """
        Perform HTTPS GET check. Param can be URL path or :<port>/<path>
        """
        url = "https://%s%s" % (self.address, param)
        return self.https_get(url)

    def is_match(self, result, method, value):
        """
        Returns True when result matches value
        """
        if method == "eq":
            return result == value
        elif method == "contains":
            return value in result
        elif method == "re":
            return bool(self.get_re(value).search(result))
        else:
            self.logger.error("Invalid match method '%s'. Ignoring", method)
            return False

    def snmp_v1_get(self, param):
        """
        Perform SNMP v1 request. May be overridden for testing
        :param param:
        :return:
        """
        self.logger.info("SNMP v1 GET: %s", param)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).snmp_v1_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def snmp_v2c_get(self, param):
        """
        Perform SNMP v2c request. May be overridden for testing
        :param param:
        :return:
        """
        self.logger.info("SNMP v2c GET: %s", param)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).snmp_v2c_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def http_get(self, url):
        """
        Perform HTTP request. May be overridden for testing
        :param url: Request URL
        :return:
        """
        self.logger.info("HTTP Request: %s", url)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).http_get(url, True)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def https_get(self, url):
        """
        Perform HTTPS request. May be overridden for testing
        :param url: Request URL
        :return:
        """
        return self.http_get(url)
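A minimal sketch of how the checker above might be exercised in a test, overriding the RPC-backed fetcher as its docstring suggests; the address, community and the canned sysDescr string are illustrative.

class StubChecker(ProfileChecker):
    def snmp_v2c_get(self, param):
        # Canned sysDescr instead of a real RPC round-trip
        return "Cisco IOS Software, C2960 Software"


checker = StubChecker(address="192.0.2.1", snmp_community="public")
result = checker.snmp_v2c_get("1.3.6.1.2.1.1.1.0")
print(checker.is_match(result, "contains", "Cisco IOS"))  # True
print(checker.is_match(result, "re", r"C29\d+"))           # True
print(checker.is_match(result, "eq", "Cisco IOS"))         # False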
Esempio n. 18
0
File: base.py Progetto: ewwwcha/noc
class BaseLoader(object):
    name = None
    base_cls = None  # Base class to be loaded
    base_path = None  # Tuple of path components
    ignored_names = set()

    def __init__(self):
        self.logger = PrefixLoggerAdapter(logger, self.name)
        self.classes = {}
        self.lock = threading.Lock()
        self.all_classes = set()

    def find_class(self, module_name, base_cls, name):
        """
        Load subclass of *base_cls* from module

        :param module_name: String containing module name
        :param base_cls: Base class
        :param name: object name
        :return: class reference or None
        """
        try:
            sm = __import__(module_name, {}, {}, "*")
            for n in dir(sm):
                o = getattr(sm, n)
                if (
                    inspect.isclass(o)
                    and issubclass(o, base_cls)
                    and o.__module__ == sm.__name__
                    and self.is_valid_class(o, name)
                ):
                    return o
        except ImportError as e:
            self.logger.error("Failed to load %s %s: %s", self.name, name, e)
        return None

    def is_valid_class(self, kls, name):
        """
        Check `find_class` found valid class
        :param kls: Class
        :param name: Class' name
        :return: True if class is valid and should be returned
        """
        return True

    def is_valid_name(self, name):
        return ".." not in name

    def get_path(self, base, name):
        """
        Get file path
        :param base: "" or custom prefix
        :param name: class name
        :return:
        """
        p = (base,) + self.base_path + ("%s.py" % name,)
        return os.path.join(*p)

    def get_module_name(self, base, name):
        """
        Get module name
        :param base: `noc` or custom prefix
        :param name: module name
        :return:
        """
        return "%s.%s.%s" % (base, ".".join(self.base_path), name)

    def get_class(self, name):
        with self.lock:
            kls = self.classes.get(name)
            if not kls:
                self.logger.info("Loading %s", name)
                if not self.is_valid_name(name):
                    self.logger.error("Invalid name: %s", name)
                    return None
                for p in config.get_customized_paths("", prefer_custom=True):
                    path = self.get_path(p, name)
                    if not os.path.exists(path):
                        continue
                    base_name = os.path.basename(os.path.dirname(p)) if p else "noc"
                    module_name = self.get_module_name(base_name, name)
                    kls = self.find_class(module_name, self.base_cls, name)
                    if kls:
                        break
                if not kls:
                    logger.error("DataStream not found: %s", name)
                self.classes[name] = kls
            return kls

    def __getitem__(self, item):
        return self.get_class(item)

    def __iter__(self):
        return self.iter_classes()

    def iter_classes(self):
        with self.lock:
            if not self.all_classes:
                self.all_classes = self.find_classes()
        for ds in sorted(self.all_classes):
            yield ds

    def find_classes(self):
        names = set()
        for dn in config.get_customized_paths(os.path.join(*self.base_path)):
            for fn in os.listdir(dn):
                if fn.startswith("_") or not fn.endswith(".py"):
                    continue
                name = fn[:-3]
                if name not in self.ignored_names:
                    names.add(name)
        return names
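A minimal sketch of a concrete loader on top of the base class above, assuming the conventional one-class-per-module layout under base_path; the plugin kind, paths and names are illustrative.

class BaseChecker(object):
    name = None


class CheckerLoader(BaseLoader):
    name = "checker"
    base_cls = BaseChecker
    base_path = ("custom", "checkers")
    ignored_names = {"base", "loader"}


loader = CheckerLoader()
# loader.get_class("ping") would import noc.custom.checkers.ping, return the
# BaseChecker subclass defined there and cache it for subsequent lookups.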
Esempio n. 19
0
class SegmentTopology(BaseTopology):
    def __init__(self,
                 segment,
                 node_hints=None,
                 link_hints=None,
                 force_spring=False):
        self.logger = PrefixLoggerAdapter(logger, segment.name)
        self.segment = segment
        self.segment_siblings = self.segment.get_siblings()
        self._uplinks_cache = {}
        self.segment_objects = set()
        if self.segment.parent:
            self.parent_segment = self.segment.parent
            self.ancestor_segments = set(self.segment.get_path()[:-1])
        else:
            self.parent_segment = None
            self.ancestor_segments = set()
        super(SegmentTopology, self).__init__(node_hints, link_hints,
                                              force_spring)

    def get_role(self, mo):
        if mo.segment in self.segment_siblings:
            return "segment"
        elif self.parent_segment and mo.segment.id in self.ancestor_segments:
            return "uplink"
        else:
            return "downlink"

    @cachetools.cachedmethod(operator.attrgetter("_uplinks_cache"))
    def get_uplinks(self):
        self.logger.info("Searching for uplinks")
        if not self.G:
            return []
        for policy in self.segment.profile.iter_uplink_policy():
            uplinks = getattr(self, "get_uplinks_%s" % policy)()
            if uplinks:
                self.logger.info(
                    "[%s] %d uplinks found: %s",
                    policy,
                    len(uplinks),
                    ", ".join(str(x) for x in uplinks),
                )
                return uplinks
            self.logger.info("[%s] No uplinks found. Skipping", policy)
        self.logger.info("Failed to find uplinks")
        return []

    def get_uplinks_seghier(self):
        """
        Find uplinks based on the segment hierarchy. Any object from a parent
        segment is an uplink
        :return:
        """
        return [
            i for i in self.G.node if self.G.node[i].get("role") == "uplink"
        ]

    def get_uplinks_molevel(self):
        """
        Find uplinks based on the Managed Object's level. Top-level objects are returned.
        :return:
        """
        max_level = max(self.G.node[i].get("level") for i in self.G.node
                        if self.G.node[i].get("type") == "managedobject")
        return [
            i for i in self.G.node
            if self.G.node[i].get("type") == "managedobject"
            and self.G.node[i].get("level") == max_level
        ]

    def get_uplinks_seg(self):
        """
        All segment objects are uplinks
        :return:
        """
        return [
            i for i in self.G.node if self.G.node[i].get("role") == "segment"
        ]

    def get_uplinks_minaddr(self):
        """
        The segment object with the lowest address is the uplink
        :return:
        """
        s = next(
            iter(
                sorted((IP.prefix(self.G.node[i].get("address")), i)
                       for i in self.G.node
                       if self.G.node[i].get("role") == "segment")))
        return [s[1]]

    def get_uplinks_maxaddr(self):
        """
        The segment object with the highest address is the uplink
        :return:
        """
        s = next(
            reversed(
                sorted((IP.prefix(self.G.node[i].get("address")), i)
                       for i in self.G.node
                       if self.G.node[i].get("role") == "segment")))
        return [s[1]]

    def load(self):
        """
        Load all managed objects from segment
        """
        def get_bandwidth(if_list):
            """
            Calculate bandwidth for list of interfaces
            :param if_list:
            :return: total in bandwidth, total out bandwidth
            """
            in_bw = 0
            out_bw = 0
            for iface in if_list:
                bw = iface.get("bandwidth") or 0
                in_speed = iface.get("in_speed") or 0
                out_speed = iface.get("out_speed") or 0
                in_bw += bandwidth(in_speed, bw)
                out_bw += bandwidth(out_speed, bw)
            return in_bw, out_bw

        def bandwidth(speed, if_bw):
            if speed and if_bw:
                return min(speed, if_bw)
            elif speed and not if_bw:
                return speed
            elif if_bw:
                return if_bw
            else:
                return 0

        # Get all links belonging to the segment
        links = list(
            Link.objects.filter(
                linked_segments__in=[s.id for s in self.segment_siblings]))
        # All linked interfaces from map
        all_ifaces = list(
            itertools.chain.from_iterable(link.interface_ids
                                          for link in links))
        # Bulk fetch all interfaces data
        ifs = dict((i["_id"], i) for i in Interface._get_collection().find(
            {"_id": {
                "$in": all_ifaces
            }},
            {
                "_id": 1,
                "managed_object": 1,
                "name": 1,
                "bandwidth": 1,
                "in_speed": 1,
                "out_speed": 1,
            },
        ))
        # Bulk fetch all managed objects
        segment_mos = set(
            self.segment.managed_objects.values_list("id", flat=True))
        all_mos = list(
            set(i["managed_object"]
                for i in six.itervalues(ifs) if "managed_object" in i)
            | segment_mos)
        mos = dict(
            (mo.id, mo) for mo in ManagedObject.objects.filter(id__in=all_mos))
        self.segment_objects = set(mo_id for mo_id in all_mos
                                   if mos[mo_id].segment.id == self.segment.id)
        for mo in six.itervalues(mos):
            self.add_object(mo)
        # Process all segment's links
        pn = 0
        for link in links:
            if link.is_loop:
                continue  # Loops are not shown on map
            # Group interfaces by objects
            # avoiding non-bulk dereferencing
            mo_ifaces = defaultdict(list)
            for if_id in link.interface_ids:
                iface = ifs[if_id]
                mo_ifaces[mos[iface["managed_object"]]] += [iface]
            # Pairs of managed objects are pseudo-links
            if len(mo_ifaces) == 2:
                # ptp link
                pseudo_links = [list(mo_ifaces)]
                is_pmp = False
            else:
                # pmp
                # Create virtual cloud
                self.add_cloud(link)
                # Create virtual links to cloud
                pseudo_links = [(link, mo) for mo in mo_ifaces]
                # Create virtual cloud interface
                mo_ifaces[link] = [{"name": "cloud"}]
                is_pmp = True
            # Link all pairs
            for mo0, mo1 in pseudo_links:
                mo0_id = str(mo0.id)
                mo1_id = str(mo1.id)
                # Create virtual ports for mo0
                self.G.node[mo0_id]["ports"] += [{
                    "id": pn,
                    "ports": [i["name"] for i in mo_ifaces[mo0]]
                }]
                # Create virtual ports for mo1
                self.G.node[mo1_id]["ports"] += [{
                    "id": pn + 1,
                    "ports": [i["name"] for i in mo_ifaces[mo1]]
                }]
                # Calculate bandwidth
                t_in_bw, t_out_bw = get_bandwidth(mo_ifaces[mo0])
                d_in_bw, d_out_bw = get_bandwidth(mo_ifaces[mo1])
                in_bw = bandwidth(t_in_bw, d_out_bw) * 1000
                out_bw = bandwidth(t_out_bw, d_in_bw) * 1000
                # Add link
                if is_pmp:
                    link_id = "%s-%s-%s" % (link.id, pn, pn + 1)
                else:
                    link_id = str(link.id)
                self.add_link(
                    mo0_id,
                    mo1_id,
                    {
                        "id": link_id,
                        "type": "link",
                        "method": link.discovery_method,
                        "ports": [pn, pn + 1],
                        # Target to source
                        "in_bw": in_bw,
                        # Source to target
                        "out_bw": out_bw,
                        # Max bandwidth
                        "bw": max(in_bw, out_bw),
                    },
                )
                pn += 2

    def max_uplink_path_len(self):
        """
        Returns the maximum path length to an uplink
        """
        n = 0
        uplinks = self.get_uplinks()
        for u in uplinks:
            for o in self.G.node:
                if o not in uplinks:
                    for p in nx.all_simple_paths(self.G, o, u):
                        n = max(n, len(p))
        return n

    def iter_uplinks(self):
        """
        Yields ObjectUplinks items for segment

        :returns: ObjectUplinks items
        """
        def get_node_uplinks(node):
            role = self.G.node[node].get("role", "cloud")
            if role == "uplink":
                # Only downlinks matter
                return []
            elif role == "downlink":
                # All segment neighbors are uplinks.
                # Since no links between downlink segments are loaded,
                # all neighbors belong to the current segment
                return list(self.G.neighbors(node))
            # Segment role and clouds
            ups = {}
            for u in uplinks:
                for path in nx.all_simple_paths(self.G, node, u):
                    lp = len(path)
                    p = path[1]
                    ups[p] = min(lp, ups.get(p, lp))
            # Shortest path first
            return sorted(ups, key=lambda x: ups[x])

        from noc.sa.models.objectdata import ObjectUplinks

        uplinks = self.get_uplinks()
        # @todo: Workaround for empty uplinks
        # Get uplinks for cloud nodes
        cloud_uplinks = dict((o, [int(u) for u in get_node_uplinks(o)])
                             for o in self.G.node
                             if self.G.node[o]["type"] == "cloud")
        # All objects including neighbors
        all_objects = set(o for o in self.G.node
                          if self.G.node[o]["type"] == "managedobject")
        # Get objects uplinks
        obj_uplinks = {}
        obj_downlinks = defaultdict(set)
        for o in all_objects:
            mo = int(o)
            ups = []
            for u in get_node_uplinks(o):
                cu = cloud_uplinks.get(u)
                if cu is not None:
                    # Uplink is a cloud. Use cloud's uplinks instead
                    ups += cu
                else:
                    ups += [int(u)]
            obj_uplinks[mo] = ups
            for u in ups:
                obj_downlinks[u].add(mo)
        # Calculate RCA neighbors and yield result
        for mo in obj_uplinks:
            # Keep only objects from the current segment. Neighbors will be
            # updated by their own segments' tasks
            if mo not in self.segment_objects:
                continue
            # All uplinks
            neighbors = set(obj_uplinks[mo])
            # All downlinks
            for dmo in obj_downlinks[mo]:
                neighbors.add(dmo)
                # And uplinks of downlinks
                neighbors |= set(obj_uplinks[dmo])
            # Not including object itself
            if mo in neighbors:
                neighbors.remove(mo)
            # Recalculated result
            yield ObjectUplinks(object_id=mo,
                                uplinks=obj_uplinks[mo],
                                rca_neighbors=list(sorted(neighbors)))
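
The per-node uplink computation in iter_uplinks() reduces to a small graph exercise: for nodes that are neither uplinks nor plain downlinks, walk all simple paths to the uplink set, remember the first hop of each path together with the length of the shortest path through it, and order the hops shortest-path-first. A self-contained illustration of that idea with plain networkx (not the NOC API; the toy graph and node names are invented):

import networkx as nx


def node_uplinks(g, node, uplinks):
    """First-hop neighbors leading towards any uplink, shortest path first."""
    ups = {}
    for u in uplinks:
        for path in nx.all_simple_paths(g, node, u):
            first_hop = path[1]  # the neighbor this path leaves through
            ups[first_hop] = min(len(path), ups.get(first_hop, len(path)))
    # Shortest path first, as in get_node_uplinks()
    return sorted(ups, key=lambda x: ups[x])


g = nx.Graph()
g.add_edges_from([("sw1", "agg1"), ("sw1", "agg2"), ("agg1", "core"), ("agg2", "core")])
print(node_uplinks(g, "sw1", {"core"}))  # e.g. ['agg1', 'agg2'] -- both two hops from the core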
Example no. 20
def wipe(o):
    if not hasattr(o, "id"):
        try:
            o = ManagedObject.objects.get(id=o)
        except ManagedObject.DoesNotExist:
            return True
    log = PrefixLoggerAdapter(logger, str(o.id))
    # Wiping discovery tasks
    log.debug("Wiping discovery tasks")
    for j in [
            ManagedObject.BOX_DISCOVERY_JOB,
            ManagedObject.PERIODIC_DISCOVERY_JOB
    ]:
        Job.remove("discovery", j, key=o.id, pool=o.pool.name)
    # Wiping FM events
    log.debug("Wiping events")
    FailedEvent.objects.filter(managed_object=o.id).delete()
    ActiveEvent.objects.filter(managed_object=o.id).delete()
    ArchivedEvent.objects.filter(managed_object=o.id).delete()
    # Wiping alarms
    log.debug("Wiping alarms")
    for ac in (ActiveAlarm, ArchivedAlarm):
        for a in ac.objects.filter(managed_object=o.id):
            # Relink root causes
            my_root = a.root
            for iac in (ActiveAlarm, ArchivedAlarm):
                for ia in iac.objects.filter(root=a.id):
                    ia.root = my_root
                    ia.save()
            # Delete alarm
            a.delete()
    # Wiping MAC DB
    log.debug("Wiping MAC DB")
    MACDB._get_collection().remove({"managed_object": o.id})
    # Wiping discovery id cache
    log.debug("Wiping discovery id")
    DiscoveryID._get_collection().remove({"object": o.id})
    # Wiping interfaces, subs and links
    # Wipe links
    log.debug("Wiping links")
    for i in Interface.objects.filter(managed_object=o.id):
        # @todo: Remove aggregated links correctly
        Link.objects.filter(interfaces=i.id).delete()
    #
    log.debug("Wiping subinterfaces")
    SubInterface.objects.filter(managed_object=o.id).delete()
    log.debug("Wiping interfaces")
    Interface.objects.filter(managed_object=o.id).delete()
    log.debug("Wiping forwarding instances")
    ForwardingInstance.objects.filter(managed_object=o.id).delete()
    # Unbind from IPAM
    log.debug("Unbind from IPAM")
    for a in Address.objects.filter(managed_object=o):
        a.managed_object = None
        a.save()
    # Wipe object status
    log.debug("Wiping object status")
    ObjectStatus.objects.filter(object=o.id).delete()
    # Wipe outages
    log.debug("Wiping outages")
    Outage.objects.filter(object=o.id).delete()
    # Wipe uptimes
    log.debug("Wiping uptimes")
    Uptime.objects.filter(object=o.id).delete()
    # Wipe reboots
    log.debug("Wiping reboots")
    Reboot.objects.filter(object=o.id).delete()
    # Delete Managed Object's capabilities
    log.debug("Wiping capabilitites")
    ObjectCapabilities.objects.filter(object=o.id).delete()
    # Delete Managed Object's attributes
    log.debug("Wiping attributes")
    ManagedObjectAttribute.objects.filter(managed_object=o).delete()
    # Finally delete object and config
    log.debug("Finally wiping object")
    o.delete()
    log.debug("Done")
Example no. 21
class BaseLoader(object):
    """
    Import directory structure:
    var/
        import/
            <system name>/
                <loader name>/
                    import.jsonl[.ext]  -- state to load, must have .ext extension
                                         according to selected compressor
                    mappings.csv -- ID mappings
                    archive/
                        import-YYYY-MM-DD-HH-MM-SS.jsonl.ext -- imported state

    Import file format: CSV, unix line endings, UTF-8, comma-separated.
    The first column is the record id in terms of the connected system;
    the other columns must be defined in the *fields* variable.

    The file must be sorted by the first field, either as strings or as
    numbers; the sort order must not change between imports.

    mappings.csv -- CSV, unix line endings, UTF-8, comma-separated
    ID mappings between NOC and the remote system. Populated by the loader
    automatically.

    :param fields: List of either field names or tuple of
        (field name, related loader name)
    """

    # Loader name
    name: str
    # Loader model (Database)
    model = None
    # Data model
    data_model: BaseModel

    # List of tags to add to the created records
    tags = []

    rx_archive = re.compile(r"^import-\d{4}(?:-\d{2}){5}.jsonl%s$" %
                            compressor.ext.replace(".", r"\."))

    # Discard records which cannot be dereferenced
    discard_deferred = False
    # Ignore auto-generated unique fields
    ignore_unique = {"bi_id"}
    # Array fields need merge values
    incremental_change = {
        "labels", "static_client_groups", "static_service_groups"
    }
    # Workflow fields
    workflow_state_sync = False
    workflow_fields = {"state", "state_changed", "event"}
    workflow_event_model = False
    workflow_add_event = "seen"
    workflow_delete_event = "missed"

    REPORT_INTERVAL = 1000

    class Deferred(Exception):
        pass

    def __init__(self, chain):
        self.chain = chain
        self.system = chain.system
        self.logger = PrefixLoggerAdapter(
            logger, "%s][%s" % (self.system.name, self.name))
        self.disable_mappings = False
        self.import_dir = os.path.join(config.path.etl_import,
                                       self.system.name, self.name)
        self.archive_dir = os.path.join(self.import_dir, "archive")
        self.mappings_path = os.path.join(self.import_dir, "mappings.csv")
        self.mappings = {}
        self.wf_state_mappings = {}
        self.new_state_path = None
        self.c_add = 0
        self.c_change = 0
        self.c_delete = 0
        # Mapped fields
        self.mapped_fields = self.data_model.get_mapped_fields()
        # Build clean map
        self.clean_map = {}  # field name -> clean function
        self.pending_deletes: List[Tuple[str,
                                         BaseModel]] = []  # (id, BaseModel)
        self.referred_errors: List[Tuple[str,
                                         BaseModel]] = []  # (id, BaseModel)
        if self.is_document:
            import mongoengine.errors

            unique_fields = [
                f.name for f in self.model._fields.values()
                if f.unique and f.name not in self.ignore_unique
            ]
            self.integrity_exception = mongoengine.errors.NotUniqueError
        else:
            # Third-party modules
            import django.db.utils

            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
                and f.name not in self.ignore_unique
            ]
            self.integrity_exception = django.db.utils.IntegrityError
        if unique_fields:
            self.unique_field = unique_fields[0]
        else:
            self.unique_field = None
        self.has_remote_system: bool = hasattr(self.model, "remote_system")
        if self.workflow_state_sync:
            self.load_wf_state_mappings()

    @property
    def is_document(self):
        """
        Returns True if the model is a MongoEngine Document,
        False if it is a Django Model
        """
        return hasattr(self.model, "_fields")

    def load_mappings(self):
        """
        Load mappings file
        """
        if self.model:
            if self.is_document:
                self.update_document_clean_map()
            else:
                self.update_model_clean_map()
        if not os.path.exists(self.mappings_path):
            return
        self.logger.info("Loading mappings from %s", self.mappings_path)
        with open(self.mappings_path) as f:
            reader = csv.reader(f)
            for k, v in reader:
                self.mappings[self.clean_str(k)] = v
        self.logger.info("%d mappings restored", len(self.mappings))

    def load_wf_state_mappings(self):
        from noc.wf.models.state import State

        self.logger.info("Loading Workflow states")
        for ws in State.objects.filter():
            self.wf_state_mappings[(str(ws.workflow.id), ws.name)] = ws

    def get_new_state(self) -> Optional[TextIOWrapper]:
        """
        Returns file object of new state, or None when not present
        """
        # Try import.jsonl (with the configured compressor extension)
        path = compressor.get_path(
            os.path.join(self.import_dir, "import.jsonl"))
        if not os.path.isfile(path):
            return None
        logger.info("Loading from %s", path)
        self.new_state_path = path
        return compressor(path, "r").open()

    def get_current_state(self) -> TextIOWrapper:
        """
        Returns file object of the current state
        (an empty stream when no archived state exists)
        """
        self.load_mappings()
        if not os.path.isdir(self.archive_dir):
            self.logger.info("Creating archive directory: %s",
                             self.archive_dir)
            try:
                os.mkdir(self.archive_dir)
            except OSError as e:
                self.logger.error("Failed to create directory: %s (%s)",
                                  self.archive_dir, e)
                # @todo: Die
        if os.path.isdir(self.archive_dir):
            fn = list(
                sorted(f for f in os.listdir(self.archive_dir)
                       if self.rx_archive.match(f)))
        else:
            fn = []
        if not fn:
            return StringIO("")
        path = os.path.join(self.archive_dir, fn[-1])
        logger.info("Current state from %s", path)
        return compressor(path, "r").open()

    def iter_jsonl(
            self,
            f: TextIOWrapper,
            data_model: Optional[BaseModel] = None) -> Iterable[BaseModel]:
        """
        Iterate over a JSONL stream and yield model instances
        :param f:
        :param data_model:
        :return:
        """
        dm = data_model or self.data_model
        for line in f:
            yield dm.parse_raw(line.replace("\\r", ""))

    def diff(
        self,
        old: Iterable[BaseModel],
        new: Iterable[BaseModel],
        include_fields: Set = None
    ) -> Iterable[Tuple[Optional[BaseModel], Optional[BaseModel]]]:
        """
        Compare old and new states (both sorted by id) and yield pairs of matches:
        * old, new -- when changed
        * old, None -- when removed
        * None, new -- when added
        """

        o = next(old, None)
        n = next(new, None)
        while o or n:
            if not o:
                # New
                yield None, n
                n = next(new, None)
            elif not n:
                # Removed
                yield o, None
                o = next(old, None)
            else:
                if n.id == o.id:
                    # Changed
                    if n.dict(include=include_fields) != o.dict(
                            include=include_fields):
                        yield o, n
                    n = next(new, None)
                    o = next(old, None)
                elif n.id < o.id:
                    # Added
                    yield None, n
                    n = next(new, None)
                else:
                    # Removed
                    yield o, None
                    o = next(old, None)

    def load(self):
        """
        Import new data
        """
        self.logger.info("Importing")
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            self.load_mappings()
            return
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        deferred_add = []
        deferred_change = []
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                try:
                    self.on_add(n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_add += [n]
            elif o and n is None:
                self.on_delete(o)
            else:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_change += [(o, n)]
            rn = self.c_add + self.c_change + self.c_delete
            if rn > 0 and rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Add deferred records
        while len(deferred_add):
            nd = []
            for row in deferred_add:
                try:
                    self.on_add(row)
                except self.Deferred:
                    nd += [row]
            if len(nd) == len(deferred_add):
                raise Exception("Unable to defer references")
            deferred_add = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Change deferred records
        while len(deferred_change):
            nd = []
            for o, n in deferred_change:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    nd += [(o, n)]
            if len(nd) == len(deferred_change):
                raise Exception("Unable to defer references")
            deferred_change = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)

    def find_object(self, v: Dict[str, Any]) -> Optional[Any]:
        """
        Find object by remote system/remote id
        :param v:
        :return:
        """
        self.logger.debug("Find object: %s", v)
        if not self.has_remote_system:
            return None
        if not v.get("remote_system") or not v.get("remote_id"):
            self.logger.warning("RS or RID not found")
            return None
        find_query = {
            "remote_system": v.get("remote_system"),
            "remote_id": v.get("remote_id")
        }
        try:
            return self.model.objects.get(**find_query)
        except self.model.MultipleObjectsReturned:
            if self.unique_field:
                find_query[self.unique_field] = v.get(self.unique_field)
                r = self.model.objects.filter(**find_query)
                if not r:
                    r = self.model.objects.filter(**find_query)
                return list(r)[-1]
            raise self.model.MultipleObjectsReturned
        except self.model.DoesNotExist:
            self.logger.debug("Object not found")
            return None

    def create_object(self, v):
        """
        Create object with attributes. Override to save complex
        data structures
        """
        self.logger.debug("Create object")
        o = self.model(**v)
        try:
            o.save()
        except self.integrity_exception as e:
            self.logger.warning("Integrity error: %s", e)
            assert self.unique_field
            if not self.is_document:
                from django.db import connection

                connection._rollback()
            # Fallback to change object
            o = self.model.objects.get(
                **{self.unique_field: v[self.unique_field]})
            for k, nv in v.items():
                setattr(o, k, nv)
            o.save()
        return o

    def change_object(self,
                      object_id: str,
                      v: Dict[str, Any],
                      inc_changes: Dict[str, Dict[str, List]] = None):
        """
        Change object with attributes
        """
        self.logger.debug("Changed object: %s", v)
        # See: https://code.getnoc.com/noc/noc/merge_requests/49
        try:
            o = self.model.objects.get(pk=object_id)
        except self.model.DoesNotExist:
            self.logger.error("Cannot change %s:%s: Does not exists",
                              self.name, object_id)
            return None
        for k, nv in v.items():
            if inc_changes and k in inc_changes:
                ov = getattr(o, k, [])
                nv = list(
                    set(ov).union(set(inc_changes[k]["add"])) -
                    set(inc_changes[k]["remove"]))
            setattr(o, k, nv)
        o.save()
        return o

    def on_add(self, item: BaseModel) -> None:
        """
        Create new record
        """
        self.logger.debug("Add: %s", item.json())
        v = self.clean(item)
        if "id" in v:
            del v["id"]
        for fn in set(v).intersection(self.workflow_fields):
            del v[fn]
        o = self.find_object(v)
        if o:
            self.c_change += 1
            # Lost&found object with same remote_id
            self.logger.debug("Lost and Found object")
            vv = {
                "remote_system": v["remote_system"],
                "remote_id": v["remote_id"]
            }
            for fn, nv in v.items():
                if fn in vv:
                    continue
                if getattr(o, fn) != nv:
                    vv[fn] = nv
            o = self.change_object(o.id, vv)
        else:
            self.c_add += 1
            o = self.create_object(v)
            if self.workflow_event_model:
                o.fire_event(self.workflow_add_event)
        if self.workflow_state_sync:
            self.change_workflow(o, getattr(item, "state", None),
                                 getattr(item, "state_changed", None))
        self.set_mappings(item.id, o.id)

    def on_change(self, o: BaseModel, n: BaseModel):
        """
        Create change record
        """
        self.logger.debug("Change: %s", n.json())
        self.c_change += 1
        nv = self.clean(n)
        changes = {
            "remote_system": nv["remote_system"],
            "remote_id": nv["remote_id"]
        }
        incremental_changes = {}
        ov = self.clean(o)
        for fn in self.data_model.__fields__:
            if fn == "id" or fn in self.workflow_fields:
                continue
            if ov[fn] != nv[fn]:
                self.logger.debug("   %s: %s -> %s", fn, ov[fn], nv[fn])
                changes[fn] = nv[fn]
                if fn in self.incremental_change:
                    incremental_changes[fn] = {
                        "add": list(set(nv[fn]) - set(ov[fn])),
                        "remove": list(set(ov[fn]) - set(nv[fn])),
                    }
        if n.id in self.mappings:
            o = self.change_object(self.mappings[n.id],
                                   changes,
                                   inc_changes=incremental_changes)
            if self.workflow_state_sync:
                self.change_workflow(o, getattr(n, "state", None),
                                     getattr(n, "state_changed", None))
        else:
            self.logger.error("Cannot map id '%s'. Skipping.", n.id)

    def on_delete(self, item: BaseModel):
        """
        Delete record
        """
        self.pending_deletes += [(item.id, item)]

    def change_workflow(self,
                        o,
                        state: str,
                        changed_date: Optional[datetime.datetime] = None):
        state = self.clean_wf_state(o.profile.workflow, state)
        if state and o.state != state:
            self.logger.debug("Change workflow state: %s -> %s", o.state,
                              state)
            o.set_state(state, changed_date)

    def purge(self):
        """
        Perform pending deletes
        """
        for r_id, msg in reversed(self.pending_deletes):
            self.logger.debug("Delete: %s", msg)
            self.c_delete += 1
            try:
                obj = self.model.objects.get(pk=self.mappings[r_id])
                if self.workflow_event_model:
                    obj.fire_event(self.workflow_delete_event)
                else:
                    obj.delete()
            except ValueError as e:  # Referred Error
                self.logger.error("%s", str(e))
                self.referred_errors += [(r_id, msg)]
            except KeyError as e:
                # Undefined mappings
                self.logger.error("%s", str(e))
            except self.model.DoesNotExist:
                pass  # Already deleted
        self.pending_deletes = []

    def save_state(self):
        """
        Save current state
        """
        if not self.new_state_path:
            return
        self.logger.info("Summary: %d new, %d changed, %d removed", self.c_add,
                         self.c_change, self.c_delete)
        self.logger.info("Error delete by referred: %s",
                         "\n".join(b.json() for _, b in self.referred_errors))
        t = time.localtime()
        archive_path = os.path.join(
            self.archive_dir,
            compressor.get_path("import-%04d-%02d-%02d-%02d-%02d-%02d.jsonl" %
                                tuple(t[:6])),
        )
        self.logger.info("Moving %s to %s", self.new_state_path, archive_path)
        if self.new_state_path.endswith(compressor.ext):
            # Simply move the file
            shutil.move(self.new_state_path, archive_path)
        else:
            # Compress the file
            self.logger.info("Compressing")
            with open(self.new_state_path,
                      "r") as s, compressor(archive_path, "w") as d:
                d.write(s.read())
            os.unlink(self.new_state_path)
        self.logger.info("Saving mappings to %s", self.mappings_path)
        mdata = "\n".join("%s,%s" % (k, self.mappings[k])
                          for k in sorted(self.mappings))
        safe_rewrite(self.mappings_path, mdata)

    def clean(self, item: BaseModel) -> Dict[str, Any]:
        """
        Cleanup row and return a dict of field name -> value
        """
        r = {
            k: self.clean_map.get(k, self.clean_any)(v)
            for k, v in item.dict().items()
        }
        # Fill integration fields
        r["remote_system"] = self.system.remote_system
        r["remote_id"] = self.clean_str(item.id)
        return r

    def clean_any(self, value: Any) -> Any:
        return value

    def clean_str(self, value) -> Optional[str]:
        if not value:
            return None
        if isinstance(value, str):
            return smart_text(value)
        return str(value)

    def clean_map_str(self, mappings, value):
        value = self.clean_str(value)
        if self.disable_mappings and not mappings:
            return value
        elif value:
            try:
                value = mappings[value]
            except KeyError:
                self.logger.warning("Deferred. Unknown map value: %s", value)
                raise self.Deferred
        return value

    def clean_bool(self, value: str) -> Optional[bool]:
        if value == "" or value is None:
            return None
        try:
            return int(value) != 0
        except ValueError:
            pass
        value = value.lower()
        return value in ("t", "true", "y", "yes")

    def clean_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = mappings[value]
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def clean_int_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = int(mappings[value])
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def clean_wf_state(self, workflow, state: str):
        if not state:
            return None
        try:
            return self.wf_state_mappings[(str(workflow.id), state)]
        except KeyError:
            self.logger.error("Unknown Workflow state value %s:%s", workflow,
                              state)
            raise ValueError(
                f"Unknown Workflow state value {workflow}:{state}", workflow,
                state)

    def set_mappings(self, rv, lv):
        self.logger.debug("Set mapping remote: %s, local: %s", rv, lv)
        self.mappings[str(rv)] = str(lv)

    def update_document_clean_map(self):
        from mongoengine.fields import BooleanField, ReferenceField
        from noc.core.mongo.fields import PlainReferenceField, ForeignKeyField

        self.logger.debug("Update Document clean map")
        for fn, ft in self.model._fields.items():
            if fn not in self.data_model.__fields__:
                continue
            if isinstance(ft, BooleanField):
                self.clean_map[fn] = self.clean_bool
            elif isinstance(ft, (PlainReferenceField, ReferenceField)):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif isinstance(ft, ForeignKeyField):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_int_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif fn in self.mapped_fields:
                self.clean_map[fn] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[fn]))

    def update_model_clean_map(self):
        from django.db.models import BooleanField, ForeignKey
        from noc.core.model.fields import DocumentReferenceField

        self.logger.debug("Update Model clean map")
        for f in self.model._meta.fields:
            if f.name not in self.data_model.__fields__:
                continue
            if isinstance(f, BooleanField):
                self.clean_map[f.name] = self.clean_bool
            elif isinstance(f, DocumentReferenceField):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.document,
                    )
            elif isinstance(f, ForeignKey):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.remote_field.model,
                    )
            elif f.name in self.mapped_fields:
                self.clean_map[f.name] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[f.name]))

    def check(self, chain):
        self.logger.info("Checking")
        # Get constraints
        if self.is_document:
            # Document
            required_fields = [
                f.name for f in self.model._fields.values()
                if f.required or f.unique
            ]
            unique_fields = [
                f.name for f in self.model._fields.values() if f.unique
            ]
        else:
            # Model
            required_fields = [
                f.name for f in self.model._meta.fields if not f.blank
            ]
            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
            ]
        if not required_fields and not unique_fields:
            self.logger.info("Nothing to check, skipping")
            return 0
        self.logger.debug("[%s] Required fields: %s", self.model,
                          required_fields)
        self.logger.debug("[%s] Unique fields: %s", self.model, unique_fields)
        self.logger.debug("[%s] Mapped fields: %s", self.model,
                          self.mapped_fields)
        # Prepare data
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            return 0
        new_state = self.iter_jsonl(ns)
        uv = set()
        m_data = {}  # field_number -> set of mapped ids
        # Load mapped ids
        for f in self.mapped_fields:
            line = chain.get_loader(self.mapped_fields[f])
            ls = line.get_new_state()
            if not ls:
                ls = line.get_current_state()
            ms = self.iter_jsonl(ls, data_model=line.data_model)
            m_data[self.data_model.__fields__[f].name] = set(row.id
                                                             for row in ms)
        # Process data
        n_errors = 0
        for row in new_state:
            row = row.dict()
            lr = len(row)
            # Check required fields
            for f in required_fields:
                if f not in self.data_model.__fields__:
                    continue
                if f not in row:
                    self.logger.error(
                        "ERROR: Required field #(%s) is missed in row: %s",
                        f,
                        # self.fields[i],
                        row,
                    )
                    n_errors += 1
                    continue
            # Check unique fields
            for f in unique_fields:
                if f in self.ignore_unique:
                    continue
                v = row[f]
                if v in uv:
                    self.logger.error(
                        "ERROR: Field #(%s) value is not unique: %s",
                        f,
                        # self.fields[i],
                        row,
                    )
                    n_errors += 1
                else:
                    uv.add(v)
            # Check mapped fields
            for i, f in enumerate(self.mapped_fields):
                if i >= lr:
                    continue
                v = row[f]
                if v and v not in m_data[f]:
                    self.logger.error(
                        "ERROR: Field #%d(%s) == '%s' refers to non-existent record: %s",
                        i,
                        f,
                        row[f],
                        row,
                    )
                    n_errors += 1
        if n_errors:
            self.logger.info("%d errors found", n_errors)
        else:
            self.logger.info("No errors found")
        return n_errors

    def check_diff(self):
        def dump(cmd, row):
            print("%s %s" % (cmd, row.json()))

        print("--- %s.%s" % (self.chain.system.name, self.name))
        ns = self.get_new_state()
        if not ns:
            return
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                dump("+", n)
            elif o and n is None:
                dump("-", o)
            else:
                dump("/", o)
                dump("\\", n)

    def check_diff_summary(self):
        i, u, d = 0, 0, 0
        ns = self.get_new_state()
        if not ns:
            return i, u, d
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                i += 1
            elif o and n is None:
                d += 1
            else:
                u += 1
        return i, u, d
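
The core of BaseLoader.load() is diff(): a two-pointer merge over two id-sorted streams that yields (old, None) for removals, (None, new) for additions and (old, new) for changes, skipping identical records. A self-contained illustration of the same merge using plain (id, data) tuples instead of pydantic models (just the technique, not the NOC API):

def diff(old, new):
    """Merge two id-sorted iterables of (id, data) records."""
    old, new = iter(old), iter(new)
    o, n = next(old, None), next(new, None)
    while o or n:
        if o is None:        # old stream exhausted -> addition
            yield None, n
            n = next(new, None)
        elif n is None:      # new stream exhausted -> removal
            yield o, None
            o = next(old, None)
        elif o[0] == n[0]:   # same id -> change, unless identical
            if o[1] != n[1]:
                yield o, n
            o, n = next(old, None), next(new, None)
        elif n[0] < o[0]:    # id present only in the new stream -> addition
            yield None, n
            n = next(new, None)
        else:                # id present only in the old stream -> removal
            yield o, None
            o = next(old, None)


current = [("1", "a"), ("2", "b"), ("4", "d")]
incoming = [("1", "a"), ("3", "c"), ("4", "x")]
print(list(diff(current, incoming)))
# [(('2', 'b'), None), (None, ('3', 'c')), (('4', 'd'), ('4', 'x'))]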
Example no. 22
File: base.py Project: nbashev/noc
class BaseRemoteSystem(object):
    extractors = {}

    extractors_order = [
        "admdiv",
        "networksegmentprofile",
        "networksegment",
        "object",
        "container",
        "resourcegroup",
        "managedobjectprofile",
        "administrativedomain",
        "authprofile",
        "ttsystem",
        "project",
        "managedobject",
        "link",
        "subscriberprofile",
        "subscriber",
        "serviceprofile",
        "service",
    ]

    def __init__(self, remote_system):
        self.remote_system = remote_system
        self.name = remote_system.name
        self.config = self.remote_system.config
        self.logger = PrefixLoggerAdapter(logger, self.name)

    def get_loader_chain(self):
        from noc.core.etl.loader.chain import LoaderChain

        chain = LoaderChain(self)
        for ld in self.extractors_order:
            chain.get_loader(ld)
        return chain

    def extract(self, extractors=None):
        extractors = extractors or []
        for en in reversed(self.extractors_order):
            if extractors and en not in extractors:
                self.logger.info("Skipping extractor %s", en)
                continue
            if en not in self.extractors:
                self.logger.info("Extractor %s is not implemented. Skipping",
                                 en)
                continue
            # @todo: Config
            xc = self.extractors[en](self)
            xc.extract()

    def load(self, loaders=None):
        loaders = loaders or []
        # Build chain
        chain = self.get_loader_chain()
        # Add & Modify
        for ll in chain:
            if loaders and ll.name not in loaders:
                ll.load_mappings()
                continue
            ll.load()
            ll.save_state()
        # Remove in reverse order
        for ll in reversed(list(chain)):
            ll.purge()

    def check(self, out):
        chain = self.get_loader_chain()
        # Check
        summary = []
        n_errors = 0
        for ll in chain:
            n = ll.check(chain)
            if n:
                ss = "%d errors" % n
            else:
                ss = "OK"
            summary += ["%s.%s: %s" % (self.name, ll.name, ss)]
            n_errors += n
        if summary:
            out.write("Summary:\n")
            out.write("\n".join(summary) + "\n")
        return n_errors

    @classmethod
    def extractor(cls, c):
        """
        Decorator for registering an extractor class under its name
        :return:
        """
        cls.extractors[c.name] = c
        return c
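
A concrete remote system subclasses BaseRemoteSystem and registers one extractor class per loader name through the extractor() decorator; extract() then instantiates each registered class with the remote system and calls its extract(). A hedged sketch of that wiring, assuming only the contract visible above (the class names and body are invented):

class MyInventoryRemoteSystem(BaseRemoteSystem):
    """Adapter for a hypothetical external inventory system."""


@MyInventoryRemoteSystem.extractor
class ManagedObjectExtractor(object):
    name = "managedobject"  # must match an entry in extractors_order

    def __init__(self, system):
        self.system = system        # the BaseRemoteSystem instance
        self.config = system.config

    def extract(self):
        # Pull records from the remote API and write the new state
        # consumed later by the loader with the same name.
        ...


# rs = MyInventoryRemoteSystem(remote_system)   # remote_system: NOC RemoteSystem record
# rs.extract(extractors=["managedobject"])
# rs.load(loaders=["managedobject"])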
Example no. 23
File: base.py Project: skripkar/noc
class MMLBase(object):
    name = "mml"
    iostream_class = None
    default_port = None
    BUFFER_SIZE = config.activator.buffer_size
    MATCH_TAIL = 256
    # Retries on immediate disconnect
    CONNECT_RETRIES = config.activator.connect_retries
    # Timeout after immediate disconnect
    CONNECT_TIMEOUT = config.activator.connect_timeout
    # compiled capabilities
    HAS_TCP_KEEPALIVE = hasattr(socket, "SO_KEEPALIVE")
    HAS_TCP_KEEPIDLE = hasattr(socket, "TCP_KEEPIDLE")
    HAS_TCP_KEEPINTVL = hasattr(socket, "TCP_KEEPINTVL")
    HAS_TCP_KEEPCNT = hasattr(socket, "TCP_KEEPCNT")
    HAS_TCP_NODELAY = hasattr(socket, "TCP_NODELAY")
    # Time until sending first keepalive probe
    KEEP_IDLE = 10
    # Keepalive packets interval
    KEEP_INTVL = 10
    # Terminate connection after N keepalive failures
    KEEP_CNT = 3

    def __init__(self, script, tos=None):
        self.script = script
        self.profile = script.profile
        self.logger = PrefixLoggerAdapter(self.script.logger, self.name)
        self.iostream = None
        self.ioloop = None
        self.command = None
        self.buffer = ""
        self.is_started = False
        self.result = None
        self.error = None
        self.is_closed = False
        self.close_timeout = None
        self.current_timeout = None
        self.tos = tos
        self.rx_mml_end = re.compile(self.script.profile.pattern_mml_end, re.MULTILINE)
        if self.script.profile.pattern_mml_continue:
            self.rx_mml_continue = re.compile(self.script.profile.pattern_mml_continue, re.MULTILINE)
        else:
            self.rx_mml_continue = None

    def close(self):
        self.script.close_current_session()
        self.close_iostream()
        if self.ioloop:
            self.logger.debug("Closing IOLoop")
            self.ioloop.close(all_fds=True)
            self.ioloop = None
        self.is_closed = True

    def close_iostream(self):
        if self.iostream:
            self.iostream.close()

    def deferred_close(self, session_timeout):
        if self.is_closed or not self.iostream:
            return
        self.logger.debug("Setting close timeout to %ss",
                          session_timeout)
        # Cannot call call_later directly due to
        # thread-safety problems
        # See tornado issue #1773
        tornado.ioloop.IOLoop.instance().add_callback(
            self._set_close_timeout,
            session_timeout
        )

    def _set_close_timeout(self, session_timeout):
        """
        Wrapper to deal with IOLoop.add_timeout thread safety problem
        :param session_timeout:
        :return:
        """
        self.close_timeout = tornado.ioloop.IOLoop.instance().call_later(
            session_timeout,
            self.close
        )

    def create_iostream(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if self.tos:
            s.setsockopt(
                socket.IPPROTO_IP, socket.IP_TOS, self.tos
            )
        if self.HAS_TCP_NODELAY:
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if self.HAS_TCP_KEEPALIVE:
            s.setsockopt(
                socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1
            )
            if self.HAS_TCP_KEEPIDLE:
                s.setsockopt(socket.SOL_TCP,
                             socket.TCP_KEEPIDLE, self.KEEP_IDLE)
            if self.HAS_TCP_KEEPINTVL:
                s.setsockopt(socket.SOL_TCP,
                             socket.TCP_KEEPINTVL, self.KEEP_INTVL)
            if self.HAS_TCP_KEEPCNT:
                s.setsockopt(socket.SOL_TCP,
                             socket.TCP_KEEPCNT, self.KEEP_CNT)
        return self.iostream_class(s, self)

    def set_timeout(self, timeout):
        if timeout:
            self.logger.debug("Setting timeout: %ss", timeout)
            self.current_timeout = datetime.timedelta(seconds=timeout)
        else:
            if self.current_timeout:
                self.logger.debug("Resetting timeouts")
            self.current_timeout = None

    def set_script(self, script):
        self.script = script
        if self.close_timeout:
            tornado.ioloop.IOLoop.instance().remove_timeout(self.close_timeout)
            self.close_timeout = None

    @tornado.gen.coroutine
    def send(self, cmd):
        # @todo: Apply encoding
        cmd = str(cmd)
        self.logger.debug("Send: %r", cmd)
        yield self.iostream.write(cmd)

    @tornado.gen.coroutine
    def submit(self):
        # Create iostream and connect, when necessary
        if not self.iostream:
            self.iostream = self.create_iostream()
            address = (
                self.script.credentials.get("address"),
                self.script.credentials.get("cli_port", self.default_port)
            )
            self.logger.debug("Connecting %s", address)
            try:
                yield self.iostream.connect(address)
            except tornado.iostream.StreamClosedError:
                self.logger.debug("Connection refused")
                self.error = MMLConnectionRefused("Connection refused")
                raise tornado.gen.Return(None)
            self.logger.debug("Connected")
            yield self.iostream.startup()
        # Perform all necessary login procedures
        if not self.is_started:
            self.is_started = True
            yield self.send(self.profile.get_mml_login(self.script))
            yield self.get_mml_response()
            if self.error:
                self.error = MMLAuthFailed(str(self.error))
                raise tornado.gen.Return(None)
        # Send command
        yield self.send(self.command)
        r = yield self.get_mml_response()
        raise tornado.gen.Return(r)

    @tornado.gen.coroutine
    def get_mml_response(self):
        result = []
        header_sep = self.profile.mml_header_separator
        while True:
            r = yield self.read_until_end()
            r = r.strip()
            # Process header
            if header_sep not in r:
                self.result = ""
                self.error = MMLBadResponse("Missing header separator")
                raise tornado.gen.Return(None)
            header, r = r.split(header_sep, 1)
            code, msg = self.profile.parse_mml_header(header)
            if code:
                # MML Error
                self.result = ""
                self.error = MMLError("%s (code=%s)" % (msg, code))
                raise tornado.gen.Return(None)
            # Process continuation
            if self.rx_mml_continue:
                # Process continued block
                offset = max(0, len(r) - self.MATCH_TAIL)
                match = self.rx_mml_continue.search(r, offset)
                if match:
                    self.logger.debug("Continuing in the next block")
                    result += [r[:match.start()]]
                    continue
            result += [r]
            break
        self.result = "".join(result)
        raise tornado.gen.Return(self.result)

    def execute(self, cmd, **kwargs):
        """
        Perform command and return result
        :param cmd:
        :param kwargs:
        :return:
        """
        if self.close_timeout:
            self.logger.debug("Removing close timeout")
            self.ioloop.remove_timeout(self.close_timeout)
            self.close_timeout = None
        self.buffer = ""
        self.command = self.profile.get_mml_command(cmd, **kwargs)
        self.error = None
        if not self.ioloop:
            self.logger.debug("Creating IOLoop")
            self.ioloop = tornado.ioloop.IOLoop()
        with Span(server=self.script.credentials.get("address"),
                  service=self.name, in_label=self.command) as s:
            self.ioloop.run_sync(self.submit)
            if self.error:
                if s:
                    s.error_text = str(self.error)
                raise self.error
            else:
                return self.result

    @tornado.gen.coroutine
    def read_until_end(self):
        connect_retries = self.CONNECT_RETRIES
        while True:
            try:
                f = self.iostream.read_bytes(self.BUFFER_SIZE,
                                             partial=True)
                if self.current_timeout:
                    r = yield tornado.gen.with_timeout(
                        self.current_timeout,
                        f
                    )
                else:
                    r = yield f
            except tornado.iostream.StreamClosedError:
                # Check whether the remote end closed the connection
                # just after it was established
                if not self.is_started and connect_retries:
                    self.logger.info(
                        "Connection reset. %d retries left. Waiting %d seconds",
                        connect_retries, self.CONNECT_TIMEOUT
                    )
                    while connect_retries:
                        yield tornado.gen.sleep(self.CONNECT_TIMEOUT)
                        connect_retries -= 1
                        self.iostream = self.create_iostream()
                        address = (
                            self.script.credentials.get("address"),
                            self.script.credentials.get("cli_port", self.default_port)
                        )
                        self.logger.debug("Connecting %s", address)
                        try:
                            yield self.iostream.connect(address)
                            break
                        except tornado.iostream.StreamClosedError:
                            if not connect_retries:
                                raise tornado.iostream.StreamClosedError()
                    continue
                else:
                    raise tornado.iostream.StreamClosedError()
            except tornado.gen.TimeoutError:
                self.logger.info("Timeout error")
                raise tornado.gen.TimeoutError("Timeout")
            self.logger.debug("Received: %r", r)
            self.buffer += r
            offset = max(0, len(self.buffer) - self.MATCH_TAIL)
            match = self.rx_mml_end.search(self.buffer, offset)
            if match:
                self.logger.debug("End of the block")
                r = self.buffer[:match.start()]
                self.buffer = self.buffer[match.end():]
                raise tornado.gen.Return(r)

    def shutdown_session(self):
        if self.profile.shutdown_session:
            self.logger.debug("Shutdown session")
            self.profile.shutdown_session(self.script)
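
read_until_end() and get_mml_response() avoid rescanning the whole receive buffer on every read: only the last MATCH_TAIL characters of the buffer are searched for the end-of-block pattern, everything before the match is handed back, and the remainder stays buffered. A standalone sketch of that tail-bounded matching with plain re (no Tornado; the end-of-block marker is an invented stand-in for pattern_mml_end):

import re


class BlockReader(object):
    """Accumulate received chunks and cut off blocks ending with a marker."""
    MATCH_TAIL = 256

    def __init__(self, end_pattern):
        self.rx_end = re.compile(end_pattern, re.MULTILINE)
        self.buffer = ""

    def feed(self, chunk):
        """Append a chunk; return the completed block, or None if unfinished."""
        self.buffer += chunk
        # As in read_until_end(): only the tail of the buffer is scanned
        offset = max(0, len(self.buffer) - self.MATCH_TAIL)
        match = self.rx_end.search(self.buffer, offset)
        if not match:
            return None
        block = self.buffer[:match.start()]
        self.buffer = self.buffer[match.end():]
        return block


r = BlockReader(r"^---\s*END\s*$")
print(repr(r.feed("RETRIEVE RESULT\nROW 1\n")))   # None -- block not finished yet
print(repr(r.feed("ROW 2\n---END\nnext data")))   # 'RETRIEVE RESULT\nROW 1\nROW 2\n'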