Example #1
class BaseExtractor(object):
    """
    Data extractor interface. Subclasses must provide
    *iter_data* method
    """

    Problem = namedtuple("Problem",
                         ["line", "is_rej", "p_class", "message", "row"])
    name = None
    PREFIX = config.path.etl_import
    REPORT_INTERVAL = 1000
    # List of rows to be used as constant data
    data = []
    # Suppress deduplication message
    suppress_deduplication_log = False

    def __init__(self, system):
        self.system = system
        self.config = system.config
        self.logger = PrefixLoggerAdapter(logger,
                                          "%s][%s" % (system.name, self.name))
        self.import_dir = os.path.join(self.PREFIX, system.name, self.name)
        self.fatal_problems = []
        self.quality_problems = []

    def register_quality_problem(self, line, p_class, message, row):
        self.quality_problems += [
            self.Problem(line=line + 1,
                         is_rej=False,
                         p_class=p_class,
                         message=message,
                         row=row)
        ]

    def register_fatal_problem(self, line, p_class, message, row):
        self.fatal_problems += [
            self.Problem(line=line + 1,
                         is_rej=True,
                         p_class=p_class,
                         message=message,
                         row=row)
        ]

    def get_new_state(self):
        if not os.path.isdir(self.import_dir):
            self.logger.info("Creating directory %s", self.import_dir)
            os.makedirs(self.import_dir)
        path = os.path.join(self.import_dir, "import.csv.gz")
        self.logger.info("Writing to %s", path)
        return gzip.GzipFile(path, "w")

    def get_problem_file(self):
        if not os.path.isdir(self.import_dir):
            self.logger.info("Creating directory %s", self.import_dir)
            os.makedirs(self.import_dir)
        path = os.path.join(self.import_dir, "import.csv.rej.gz")
        self.logger.info("Writing to %s", path)
        return gzip.GzipFile(path, "w")

    def iter_data(self):
        for row in self.data:
            yield row

    def filter(self, row):
        return True

    def clean(self, row):
        return row

    def extract(self):
        def q(s):
            if s == "" or s is None:
                return ""
            elif isinstance(s, six.text_type):
                return s.encode("utf-8")
            else:
                return str(s)
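        # Illustration (not from the source): q(None) -> "", q(42) -> "42",
        # q(u"text") -> "text" encoded as UTF-8 bytes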

        # Fetch data
        self.logger.info("Extracting %s from %s", self.name, self.system.name)
        t0 = perf_counter()
        data = []
        n = 0
        seen = set()
        for row in self.iter_data():
            if not self.filter(row):
                continue
            row = self.clean(row)
            if row[0] in seen:
                if not self.suppress_deduplication_log:
                    self.logger.error("Duplicated row truncated: %r", row)
                continue
            else:
                seen.add(row[0])
            data += [[q(x) for x in row]]
            n += 1
            if n % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", n)
        dt = perf_counter() - t0
        speed = n / dt
        self.logger.info("%d records extracted in %.2fs (%d records/s)", n, dt,
                         speed)
        # Sort
        data.sort()
        # Write
        f = self.get_new_state()
        writer = csv.writer(f)
        writer.writerows(data)
        f.close()
        if self.fatal_problems or self.quality_problems:
            self.logger.warning(
                "Detect problems on extracting, fatal: %d, quality: %d",
                len(self.fatal_problems),
                len(self.quality_problems),
            )
            self.logger.warning("Line num\tType\tProblem string")
            for p in self.fatal_problems:
                self.logger.warning(
                    "Fatal problem, line was rejected: %s\t%s\t%s",
                    p.line, p.p_class, p.message)
            for p in self.quality_problems:
                self.logger.warning(
                    "Data quality problem in line: %s\t%s\t%s",
                    p.line, p.p_class, p.message)
            # Dump problem to file
            try:
                f = self.get_problem_file()
                writer = csv.writer(f, delimiter=";")
                for p in itertools.chain(self.quality_problems,
                                         self.fatal_problems):
                    writer.writerow(
                        [str(c).encode("utf-8") for c in p.row]
                        + ["Fatal problem, line was rejected"
                           if p.is_rej else "Data quality problem"]
                        + [p.message.encode("utf-8")])
            except IOError as e:
                self.logger.error("Error when saved problems %s", e)
            finally:
                f.close()
        else:
            self.logger.info("No problems detected")
Example #2
class BaseLoader(object):
    """
    Import directory structure:
    var/
        import/
            <system name>/
                <loader name>/
                    import.csv[.gz]  -- state to load, can have .gz extension
                    mappings.csv -- ID mappings
                    archive/
                        import-YYYY-MM-DD-HH-MM-SS.csv.gz -- imported state

    Import file format: CSV, unix end of lines, UTF-8, comma-separated.
    The first column is the record id in terms of the connected system;
    the other columns must be defined in the *fields* variable.

    The file must be sorted by the first field, either as strings or as numbers;
    the sort order must not change between imports.

    mappings.csv - CSV, unix end of lines, UTF-8, comma-separated
    mappings of IDs between NOC and the remote system. Populated by the loader
    automatically.

    :param fields: List of either field names or tuple of
        (field name, related loader name)
    """
    # Loader name
    name = None
    # Loader model
    model = None
    # Mapped fields
    mapped_fields = {}

    fields = []

    # List of tags to add to the created records
    tags = []

    PREFIX = config.path.etl_import
    rx_archive = re.compile("^import-\d{4}(?:-\d{2}){5}.csv.gz$")

    # Discard records which cannot be dereferenced
    discard_deferred = False
    # Ignore auto-generated unique fields
    ignore_unique = set(["bi_id"])

    REPORT_INTERVAL = 1000

    class Deferred(Exception):
        pass
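
    # Deferred is raised by the clean_* helpers when a referenced record has no
    # mapping yet; load() collects such rows and retries them once the rest of
    # the batch has been processed.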

    def __init__(self, chain):
        self.chain = chain
        self.system = chain.system
        self.logger = PrefixLoggerAdapter(
            logger, "%s][%s" % (self.system.name, self.name))
        self.import_dir = os.path.join(self.PREFIX, self.system.name,
                                       self.name)
        self.archive_dir = os.path.join(self.import_dir, "archive")
        self.mappings_path = os.path.join(self.import_dir, "mappings.csv")
        self.mappings = {}
        self.new_state_path = None
        self.c_add = 0
        self.c_change = 0
        self.c_delete = 0
        # Build clean map
        self.clean_map = dict(
            (n, self.clean_str)
            for n in self.fields)  # field name -> clean function
        self.pending_deletes = []  # (id, string)
        self.reffered_errors = []  # (id, string)
        if self.is_document:
            import mongoengine.errors
            unique_fields = [
                f.name for f in self.model._fields.itervalues()
                if f.unique and f.name not in self.ignore_unique
            ]
            self.integrity_exception = mongoengine.errors.NotUniqueError
        else:
            # Third-party modules
            import django.db.utils

            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
                and f.name not in self.ignore_unique
            ]
            self.integrity_exception = django.db.utils.IntegrityError
        if unique_fields:
            self.unique_field = unique_fields[0]
        else:
            self.unique_field = None

    @property
    def is_document(self):
        """
        Returns True if model is Document, False - if Model
        """
        return hasattr(self.model, "_fields")

    def load_mappings(self):
        """
        Load mappings file
        """
        if self.model:
            if self.is_document:
                self.update_document_clean_map()
            else:
                self.update_model_clean_map()
        if not os.path.exists(self.mappings_path):
            return
        self.logger.info("Loading mappings from %s", self.mappings_path)
        with open(self.mappings_path) as f:
            reader = csv.reader(f)
            for k, v in reader:
                self.mappings[k] = v
        self.logger.info("%d mappings restored", len(self.mappings))

    def get_new_state(self):
        """
        Returns file object of new state, or None when not present
        """
        # Try import.csv
        path = os.path.join(self.import_dir, "import.csv")
        if os.path.isfile(path):
            logger.info("Loading from %s", path)
            self.new_state_path = path
            return open(path, "r")
        # Try import.csv.gz
        path += ".gz"
        if os.path.isfile(path):
            logger.info("Loading from %s", path)
            self.new_state_path = path
            return gzip.GzipFile(path, "r")
        # No data to import
        return None

    def get_current_state(self):
        """
        Returns file object of current state or None
        """
        self.load_mappings()
        if not os.path.isdir(self.archive_dir):
            self.logger.info("Creating archive directory: %s",
                             self.archive_dir)
            try:
                os.mkdir(self.archive_dir)
            except OSError as e:
                self.logger.error("Failed to create directory: %s (%s)",
                                  self.archive_dir, e)
                # @todo: Die
        if os.path.isdir(self.archive_dir):
            fn = sorted(f for f in os.listdir(self.archive_dir)
                        if self.rx_archive.match(f))
        else:
            fn = []
        if fn:
            path = os.path.join(self.archive_dir, fn[-1])
            logger.info("Current state from %s", path)
            return gzip.GzipFile(path, "r")
        # No current state
        return six.StringIO("")

    def diff(self, old, new):
        """
        Compare old and new CSV files and yield pair of matches
        * old, new -- when changed
        * old, None -- when removed
        * None, new -- when added
        """
        def getnext(g):
            try:
                return next(g)
            except StopIteration:
                return None

        o = getnext(old)
        n = getnext(new)
        while o or n:
            if not o:
                # New
                yield None, n
                n = getnext(new)
            elif not n:
                # Removed
                yield o, None
                o = getnext(old)
            else:
                if n[0] == o[0]:
                    # Changed
                    if n != o:
                        yield o, n
                    n = getnext(new)
                    o = getnext(old)
                elif n[0] < o[0]:
                    # Added
                    yield None, n
                    n = getnext(new)
                else:
                    # Removed
                    yield o, None
                    o = getnext(old)
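
    # Illustration (hypothetical rows, not from the source). With both inputs
    # sorted by the first column:
    #   old: ("1", "a"), ("2", "b"), ("4", "d")
    #   new: ("1", "a"), ("2", "x"), ("3", "c")
    # diff() yields:
    #   (("2", "b"), ("2", "x"))  -- changed
    #   (None, ("3", "c"))        -- added
    #   (("4", "d"), None)        -- removed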

    def load(self):
        """
        Import new data
        """
        self.logger.info("Importing")
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            self.load_mappings()
            return
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        deferred_add = []
        deferred_change = []
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                try:
                    self.on_add(n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_add += [n]
            elif o and n is None:
                self.on_delete(o)
            else:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_change += [(o, n)]
            rn = self.c_add + self.c_change + self.c_delete
            if rn > 0 and rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Add deferred records
        while len(deferred_add):
            nd = []
            for row in deferred_add:
                try:
                    self.on_add(row)
                except self.Deferred:
                    nd += [row]
            if len(nd) == len(deferred_add):
                raise Exception("Unable to defer references")
            deferred_add = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Change deferred records
        while len(deferred_change):
            nd = []
            for o, n in deferred_change:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    nd += [(o, n)]
            if len(nd) == len(deferred_change):
                raise Exception("Unable to defer references")
            deferred_change = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)

    def find_object(self, v):
        """
        Find object by remote system/remote id
        :param v:
        :return:
        """
        if not v.get("remote_system") or not v.get("remote_id"):
            self.logger.warning("RS or RID not found")
            return None
        find_query = {
            "remote_system": v.get("remote_system"),
            "remote_id": v.get("remote_id")
        }
        try:
            return self.model.objects.get(**find_query)
        except self.model.MultipleObjectsReturned:
            if self.unique_field:
                find_query[self.unique_field] = v.get(self.unique_field)
                r = self.model.objects.filter(**find_query)
                if not r:
                    r = self.model.objects.filter(**find_query)
                return list(r)[-1]
            raise self.model.MultipleObjectsReturned
        except self.model.DoesNotExist:
            return None

    def create_object(self, v):
        """
        Create object with attributes. Override to save complex
        data structures
        """
        for k, nv in six.iteritems(v):
            if k == "tags":
                # Merge tags
                nv = sorted("%s:%s" % (self.system.name, x) for x in nv)
                v[k] = nv
        o = self.model(**v)
        try:
            o.save()
        except self.integrity_exception as e:
            self.logger.warning("Integrity error: %s", e)
            assert self.unique_field
            if not self.is_document:
                from django.db import connection
                connection._rollback()
            # Fallback to change object
            o = self.model.objects.get(
                **{self.unique_field: v[self.unique_field]})
            for k, nv in six.iteritems(v):
                setattr(o, k, nv)
            o.save()
        return o

    def change_object(self, object_id, v):
        """
        Change object with attributes
        """
        self.logger.debug("Changed object")
        # See: https://code.getnoc.com/noc/noc/merge_requests/49
        try:
            o = self.model.objects.get(pk=object_id)
        except self.model.DoesNotExist:
            self.logger.error("Cannot change %s:%s: Does not exists",
                              self.name, object_id)
            return None
        for k, nv in six.iteritems(v):
            if k == "tags":
                # Merge tags
                ov = o.tags or []
                nv = sorted([
                    x for x in ov if not x.startswith(self.system.name + ":")
                ] + ["%s:%s" % (self.system.name, x) for x in nv])
            setattr(o, k, nv)
        o.save()
        return o

    def on_add(self, row):
        """
        Create new record
        """
        self.logger.debug("Add: %s", ";".join(row))
        v = self.clean(row)
        # @todo: Check whether the record already exists
        if self.fields[0] in v:
            del v[self.fields[0]]
        if hasattr(self.model, "remote_system"):
            o = self.find_object(v)
        else:
            o = None
        if o:
            self.c_change += 1
            # Lost&found object with same remote_id
            self.logger.debug("Lost and Found object")
            vv = {
                "remote_system": v["remote_system"],
                "remote_id": v["remote_id"]
            }
            # for fn, nv in zip(self.fields[1:], row[1:]):
            for fn, nv in six.iteritems(v):
                if fn in vv:
                    continue
                if getattr(o, fn) != nv:
                    vv[fn] = nv
            self.change_object(o.id, vv)
            # Restore mappings
            self.set_mappings(row[0], o.id)
        else:
            self.c_add += 1
            o = self.create_object(v)
            self.set_mappings(row[0], o.id)

    def on_change(self, o, n):
        """
        Create change record
        """
        self.logger.debug("Change: %s", ";".join(n))
        self.c_change += 1
        v = self.clean(n)
        vv = {"remote_system": v["remote_system"], "remote_id": v["remote_id"]}
        for fn, (ov, nv) in zip(self.fields[1:],
                                itertools.izip_longest(o[1:], n[1:])):
            if ov != nv:
                self.logger.debug("   %s: %s -> %s", fn, ov, nv)
                vv[fn] = v[fn]
        if n[0] in self.mappings:
            self.change_object(self.mappings[n[0]], vv)
        else:
            self.logger.error("Cannot map id '%s'. Skipping.", n[0])

    def on_delete(self, row):
        """
        Delete record
        """
        self.pending_deletes += [(row[0], ";".join(row))]

    def purge(self):
        """
        Perform pending deletes
        """
        for r_id, msg in reversed(self.pending_deletes):
            self.logger.debug("Delete: %s", msg)
            self.c_delete += 1
            try:
                obj = self.model.objects.get(pk=self.mappings[r_id])
                obj.delete()
            except ValueError as e:  # Referred error
                self.logger.error("%s", str(e))
                self.reffered_errors += [(r_id, msg)]
            except self.model.DoesNotExist:
                pass  # Already deleted
        self.pending_deletes = []

    def save_state(self):
        """
        Save current state
        """
        if not self.new_state_path:
            return
        self.logger.info("Summary: %d new, %d changed, %d removed", self.c_add,
                         self.c_change, self.c_delete)
        self.logger.info("Error delete by reffered: %s",
                         "\n".join(self.reffered_errors))
        t = time.localtime()
        archive_path = os.path.join(
            self.archive_dir,
            "import-%04d-%02d-%02d-%02d-%02d-%02d.csv.gz" % tuple(t[:6]))
        self.logger.info("Moving %s to %s", self.new_state_path, archive_path)
        if self.new_state_path.endswith(".gz"):
            # Simply move the file
            shutil.move(self.new_state_path, archive_path)
        else:
            # Compress the file
            self.logger.info("Compressing")
            with open(self.new_state_path, "r") as s:
                with gzip.open(archive_path, "w") as d:
                    d.write(s.read())
            os.unlink(self.new_state_path)
        self.logger.info("Saving mappings to %s", self.mappings_path)
        mdata = "\n".join("%s,%s" % (k, self.mappings[k])
                          for k in sorted(self.mappings))
        safe_rewrite(self.mappings_path, mdata)

    def clean(self, row):
        """
        Cleanup row and return a dict of field name -> value
        """
        r = dict((k, self.clean_map[k](v)) for k, v in zip(self.fields, row))
        # Fill integration fields
        r["remote_system"] = self.system.remote_system
        r["remote_id"] = self.clean_str(row[0])
        return r
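
    # Example (hypothetical values): with fields = ["id", "name"] and
    # row = ["10", "Object"], clean() returns
    # {"id": u"10", "name": u"Object",
    #  "remote_system": self.system.remote_system, "remote_id": u"10"}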

    def clean_str(self, value):
        if value:
            if isinstance(value, str):
                return unicode(value, "utf-8")
            elif not isinstance(value, six.string_types):
                return str(value)
            else:
                return value
        else:
            return None

    def clean_map_str(self, mappings, value):
        value = self.clean_str(value)
        if value:
            try:
                value = mappings[value]
            except KeyError:
                raise self.Deferred
        return value
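
    # Example (hypothetical values): with mappings = {"10": "42"},
    # clean_map_str(mappings, "10") returns "42"; an unknown key raises
    # Deferred so load() can retry the row later.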

    def clean_bool(self, value):
        if value == "":
            return None
        try:
            return int(value) != 0
        except ValueError:
            pass
        value = value.lower()
        return value in ("t", "true", "y", "yes")

    def clean_reference(self, mappings, r_model, value):
        if not value:
            return None
        else:
            # @todo: Get proper mappings
            try:
                value = mappings[value]
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def clean_int_reference(self, mappings, r_model, value):
        if not value:
            return None
        else:
            # @todo: Get proper mappings
            try:
                value = int(mappings[value])
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def set_mappings(self, rv, lv):
        self.logger.debug("Set mapping remote: %s, local: %s", rv, lv)
        self.mappings[str(rv)] = str(lv)

    def update_document_clean_map(self):
        from mongoengine.fields import BooleanField, ReferenceField
        from noc.lib.nosql import PlainReferenceField, ForeignKeyField

        for fn, ft in six.iteritems(self.model._fields):
            if fn not in self.clean_map:
                continue
            if isinstance(ft, BooleanField):
                self.clean_map[fn] = self.clean_bool
            elif isinstance(ft, (PlainReferenceField, ReferenceField)):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type)
            elif isinstance(ft, ForeignKeyField):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_int_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type)
            elif fn in self.mapped_fields:
                self.clean_map[fn] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[fn]))

    def update_model_clean_map(self):
        from django.db.models import BooleanField, ForeignKey
        from noc.core.model.fields import DocumentReferenceField

        for f in self.model._meta.fields:
            if f.name not in self.clean_map:
                continue
            if isinstance(f, BooleanField):
                self.clean_map[f.name] = self.clean_bool
            elif isinstance(f, DocumentReferenceField):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.document)
            elif isinstance(f, ForeignKey):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.rel.to)
            elif f.name in self.mapped_fields:
                self.clean_map[f.name] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[f.name]))

    def check(self, chain):
        self.logger.info("Checking")
        # Get constraints
        if self.is_document:
            # Document
            required_fields = [
                f.name for f in self.model._fields.itervalues()
                if f.required or f.unique
            ]
            unique_fields = [
                f.name for f in self.model._fields.itervalues() if f.unique
            ]
        else:
            # Model
            required_fields = [
                f.name for f in self.model._meta.fields if not f.blank
            ]
            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
            ]
        if not required_fields and not unique_fields:
            self.logger.info("Nothing to check, skipping")
            return 0
        # Prepare data
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            return 0
        new_state = csv.reader(ns)
        r_index = set(
            self.fields.index(f) for f in required_fields if f in self.fields)
        u_index = set(
            self.fields.index(f) for f in unique_fields
            if f not in self.ignore_unique)
        m_index = set(self.fields.index(f) for f in self.mapped_fields)
        uv = set()
        m_data = {}  # field_number -> set of mapped ids
        # Load mapped ids
        for f in self.mapped_fields:
            line = chain.get_loader(self.mapped_fields[f])
            ls = line.get_new_state()
            if not ls:
                ls = line.get_current_state()
            ms = csv.reader(ls)
            m_data[self.fields.index(f)] = set(row[0] for row in ms)
        # Process data
        n_errors = 0
        for row in new_state:
            lr = len(row)
            # Check required fields
            for i in r_index:
                if not row[i]:
                    self.logger.error(
                        "ERROR: Required field #%d(%s) is missed in row: %s",
                        i, self.fields[i], ",".join(row))
                    n_errors += 1
                    continue
            # Check unique fields
            for i in u_index:
                v = row[i]
                if (i, v) in uv:
                    self.logger.error(
                        "ERROR: Field #%d(%s) value is not unique: %s", i,
                        self.fields[i], ",".join(row))
                    n_errors += 1
                else:
                    uv.add((i, v))
            # Check mapped fields
            for i in m_index:
                if i >= lr:
                    continue
                v = row[i]
                if v and v not in m_data[i]:
                    self.logger.error(
                        "ERROR: Field #%d(%s) == '%s' refers to non-existent record: %s",
                        i, self.fields[i], row[i], ",".join(row))
                    n_errors += 1
        if n_errors:
            self.logger.info("%d errors found", n_errors)
        else:
            self.logger.info("No errors found")
        return n_errors

    def check_diff(self):
        def dump(cmd, row):
            print("%s %s" % (cmd, ",".join(row)))

        print("--- %s.%s" % (self.chain.system.name, self.name))
        ns = self.get_new_state()
        if not ns:
            return
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                dump("+", n)
            elif o and n is None:
                dump("-", o)
            else:
                dump("/", o)
                dump("\\", n)

    def check_diff_summary(self):
        i, u, d = 0, 0, 0
        ns = self.get_new_state()
        if not ns:
            return i, u, d
        current_state = csv.reader(self.get_current_state())
        new_state = csv.reader(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                i += 1
            elif o and n is None:
                d += 1
            else:
                u += 1
        return i, u, d
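
A minimal sketch of how this loader might be subclassed and driven (hypothetical names; `City` is assumed to be a NOC model and `chain` a loader chain exposing `system` and `get_loader()` as used above):

class CityLoader(BaseLoader):
    name = "city"
    model = City  # assumed Django model or MongoEngine document
    fields = ["id", "name", "region"]  # the first field is the remote id
    mapped_fields = {"region": "region"}  # field name -> related loader name

loader = CityLoader(chain)
loader.load()        # replays the diff between the archived and the new import.csv
loader.save_state()  # archives the imported file and persists mappings.csv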
Example #3
class BaseLoader(object):
    """
    Import directory structure:
    var/
        import/
            <system name>/
                <loader name>/
                    import.jsonl[.ext]  -- state to load; the .ext suffix is the
                                         extension of the selected compressor
                    mappings.csv -- ID mappings
                    archive/
                        import-YYYY-MM-DD-HH-MM-SS.jsonl.ext -- imported state

    Import file format: JSONL, unix end of lines, UTF-8.
    Each line is a JSON object matching the loader's *data_model*;
    the *id* field is the record id in terms of the connected system,
    the other fields are defined by *data_model*.

    The file must be sorted by the *id* field, either as strings or as numbers;
    the sort order must not change between imports.

    mappings.csv - CSV, unix end of lines, UTF-8, comma-separated
    mappings of IDs between NOC and the remote system. Populated by the loader
    automatically.

    :param fields: List of either field names or tuple of
        (field name, related loader name)
    """

    # Loader name
    name: str
    # Loader model (Database)
    model = None
    # Data model
    data_model: BaseModel

    # List of tags to add to the created records
    tags = []

    rx_archive = re.compile(r"^import-\d{4}(?:-\d{2}){5}.jsonl%s$" %
                            compressor.ext.replace(".", r"\."))

    # Discard records which cannot be dereferenced
    discard_deferred = False
    # Ignore auto-generated unique fields
    ignore_unique = {"bi_id"}

    REPORT_INTERVAL = 1000

    class Deferred(Exception):
        pass

    def __init__(self, chain):
        self.chain = chain
        self.system = chain.system
        self.logger = PrefixLoggerAdapter(
            logger, "%s][%s" % (self.system.name, self.name))
        self.disable_mappings = False
        self.import_dir = os.path.join(config.path.etl_import,
                                       self.system.name, self.name)
        self.archive_dir = os.path.join(self.import_dir, "archive")
        self.mappings_path = os.path.join(self.import_dir, "mappings.csv")
        self.mappings = {}
        self.new_state_path = None
        self.c_add = 0
        self.c_change = 0
        self.c_delete = 0
        # Mapped fields
        self.mapped_fields = self.data_model.get_mapped_fields()
        # Build clean map
        self.clean_map = {}  # field name -> clean function
        # (id, BaseModel)
        self.pending_deletes: List[Tuple[str, BaseModel]] = []
        # (id, BaseModel)
        self.referred_errors: List[Tuple[str, BaseModel]] = []
        if self.is_document:
            import mongoengine.errors

            unique_fields = [
                f.name for f in self.model._fields.values()
                if f.unique and f.name not in self.ignore_unique
            ]
            self.integrity_exception = mongoengine.errors.NotUniqueError
        else:
            # Third-party modules
            import django.db.utils

            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
                and f.name not in self.ignore_unique
            ]
            self.integrity_exception = django.db.utils.IntegrityError
        if unique_fields:
            self.unique_field = unique_fields[0]
        else:
            self.unique_field = None
        self.has_remote_system: bool = hasattr(self.model, "remote_system")

    @property
    def is_document(self):
        """
        Returns True if model is Document, False - if Model
        """
        return hasattr(self.model, "_fields")

    def load_mappings(self):
        """
        Load mappings file
        """
        if self.model:
            if self.is_document:
                self.update_document_clean_map()
            else:
                self.update_model_clean_map()
        if not os.path.exists(self.mappings_path):
            return
        self.logger.info("Loading mappings from %s", self.mappings_path)
        with open(self.mappings_path) as f:
            reader = csv.reader(f)
            for k, v in reader:
                self.mappings[self.clean_str(k)] = v
        self.logger.info("%d mappings restored", len(self.mappings))

    def get_new_state(self) -> Optional[TextIOWrapper]:
        """
        Returns file object of new state, or None when not present
        """
        # Try import.jsonl (with the compressor extension)
        path = compressor.get_path(
            os.path.join(self.import_dir, "import.jsonl"))
        if not os.path.isfile(path):
            return None
        logger.info("Loading from %s", path)
        self.new_state_path = path
        return compressor(path, "r").open()

    def get_current_state(self) -> TextIOWrapper:
        """
        Returns file object of current state or None
        """
        self.load_mappings()
        if not os.path.isdir(self.archive_dir):
            self.logger.info("Creating archive directory: %s",
                             self.archive_dir)
            try:
                os.mkdir(self.archive_dir)
            except OSError as e:
                self.logger.error("Failed to create directory: %s (%s)",
                                  self.archive_dir, e)
                # @todo: Die
        if os.path.isdir(self.archive_dir):
            fn = list(
                sorted(f for f in os.listdir(self.archive_dir)
                       if self.rx_archive.match(f)))
        else:
            fn = []
        if not fn:
            return StringIO("")
        path = os.path.join(self.archive_dir, fn[-1])
        logger.info("Current state from %s", path)
        return compressor(path, "r").open()

    def iter_jsonl(
            self,
            f: TextIOWrapper,
            data_model: Optional[BaseModel] = None) -> Iterable[BaseModel]:
        """
        Iterate over JSONl stream and yield model instances
        :param f:
        :param data_model:
        :return:
        """
        dm = data_model or self.data_model
        for line in f:
            yield dm.parse_raw(line.replace("\\r", ""))
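
    # Example (hypothetical fields): the line '{"id": "10", "name": "Object"}'
    # yields data_model(id="10", name="Object"), assuming those fields exist
    # on the model.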

    def diff(
        self,
        old: Iterable[BaseModel],
        new: Iterable[BaseModel],
        include_fields: Optional[Set] = None
    ) -> Iterable[Tuple[Optional[BaseModel], Optional[BaseModel]]]:
        """
        Compare old and new CSV files and yield pair of matches
        * old, new -- when changed
        * old, None -- when removed
        * None, new -- when added
        """

        o = next(old, None)
        n = next(new, None)
        while o or n:
            if not o:
                # New
                yield None, n
                n = next(new, None)
            elif not n:
                # Removed
                yield o, None
                o = next(old, None)
            else:
                if n.id == o.id:
                    # Changed
                    if n.dict(include=include_fields) != o.dict(
                            include=include_fields):
                        yield o, n
                    n = next(new, None)
                    o = next(old, None)
                elif n.id < o.id:
                    # Added
                    yield None, n
                    n = next(new, None)
                else:
                    # Removed
                    yield o, None
                    o = next(old, None)
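
    # Illustration (hypothetical records): with old ids ["1", "2", "4"] and
    # new ids ["1", "2", "3"], and record "2" modified, diff() yields
    # (old_2, new_2), then (None, new_3), then (old_4, None).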

    def load(self):
        """
        Import new data
        """
        self.logger.info("Importing")
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            self.load_mappings()
            return
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        deferred_add = []
        deferred_change = []
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                try:
                    self.on_add(n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_add += [n]
            elif o and n is None:
                self.on_delete(o)
            else:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    if not self.discard_deferred:
                        deferred_change += [(o, n)]
            rn = self.c_add + self.c_change + self.c_delete
            if rn > 0 and rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Add deferred records
        while len(deferred_add):
            nd = []
            for row in deferred_add:
                try:
                    self.on_add(row)
                except self.Deferred:
                    nd += [row]
            if len(nd) == len(deferred_add):
                raise Exception("Unable to defer references")
            deferred_add = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)
        # Change deferred records
        while len(deferred_change):
            nd = []
            for o, n in deferred_change:
                try:
                    self.on_change(o, n)
                except self.Deferred:
                    nd += [(o, n)]
            if len(nd) == len(deferred_change):
                raise Exception("Unable to defer references")
            deferred_change = nd
            rn = self.c_add + self.c_change + self.c_delete
            if rn % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", rn)

    def find_object(self, v: Dict[str, Any]) -> Optional[Any]:
        """
        Find object by remote system/remote id
        :param v:
        :return:
        """
        self.logger.debug("Find object: %s", v)
        if not self.has_remote_system:
            return None
        if not v.get("remote_system") or not v.get("remote_id"):
            self.logger.warning("RS or RID not found")
            return None
        find_query = {
            "remote_system": v.get("remote_system"),
            "remote_id": v.get("remote_id")
        }
        try:
            return self.model.objects.get(**find_query)
        except self.model.MultipleObjectsReturned:
            if self.unique_field:
                find_query[self.unique_field] = v.get(self.unique_field)
                r = self.model.objects.filter(**find_query)
                if not r:
                    r = self.model.objects.filter(**find_query)
                return list(r)[-1]
            raise self.model.MultipleObjectsReturned
        except self.model.DoesNotExist:
            self.logger.debug("Object not found")
            return None

    def create_object(self, v):
        """
        Create object with attributes. Override to save complex
        data structures
        """
        self.logger.debug("Create object")
        for k, nv in v.items():
            if k == "tags":
                # Merge tags
                nv = sorted("%s:%s" % (self.system.name, x) for x in nv)
                v[k] = nv
        o = self.model(**v)
        try:
            o.save()
        except self.integrity_exception as e:
            self.logger.warning("Integrity error: %s", e)
            assert self.unique_field
            if not self.is_document:
                from django.db import connection

                connection._rollback()
            # Fallback to change object
            o = self.model.objects.get(
                **{self.unique_field: v[self.unique_field]})
            for k, nv in v.items():
                setattr(o, k, nv)
            o.save()
        return o

    def change_object(self, object_id: str, v: Dict[str, Any]):
        """
        Change object with attributes
        """
        self.logger.debug("Changed object: %s", v)
        # See: https://code.getnoc.com/noc/noc/merge_requests/49
        try:
            o = self.model.objects.get(pk=object_id)
        except self.model.DoesNotExist:
            self.logger.error("Cannot change %s:%s: Does not exists",
                              self.name, object_id)
            return None
        for k, nv in v.items():
            if k == "tags":
                # Merge tags
                ov = o.tags or []
                nv = sorted([
                    x for x in ov if not (x.startswith(self.system.name + ":")
                                          or x == "remote:deleted")
                ] + ["%s:%s" % (self.system.name, x) for x in nv])
            setattr(o, k, nv)
        o.save()
        return o

    def on_add(self, item: BaseModel) -> None:
        """
        Create new record
        """
        self.logger.debug("Add: %s", item.json())
        v = self.clean(item)
        # @todo: Check whether the record already exists
        if "id" in v:
            del v["id"]
        o = self.find_object(v)
        if o:
            self.c_change += 1
            # Lost&found object with same remote_id
            self.logger.debug("Lost and Found object")
            vv = {
                "remote_system": v["remote_system"],
                "remote_id": v["remote_id"]
            }
            for fn, nv in v.items():
                if fn in vv:
                    continue
                if getattr(o, fn) != nv:
                    vv[fn] = nv
            self.change_object(o.id, vv)
        else:
            self.c_add += 1
            o = self.create_object(v)
        self.set_mappings(item.id, o.id)

    def on_change(self, o: BaseModel, n: BaseModel):
        """
        Create change record
        """
        self.logger.debug("Change: %s", n.json())
        self.c_change += 1
        nv = self.clean(n)
        changes = {
            "remote_system": nv["remote_system"],
            "remote_id": nv["remote_id"]
        }
        ov = self.clean(o)
        for fn in self.data_model.__fields__:
            if fn == "id":
                continue
            if ov[fn] != nv[fn]:
                self.logger.debug("   %s: %s -> %s", fn, ov[fn], nv[fn])
                changes[fn] = nv[fn]
        if n.id in self.mappings:
            self.change_object(self.mappings[n.id], changes)
        else:
            self.logger.error("Cannot map id '%s'. Skipping.", n.id)

    def on_delete(self, item: BaseModel):
        """
        Delete record
        """
        self.pending_deletes += [(item.id, item)]

    def purge(self):
        """
        Perform pending deletes
        """
        for r_id, msg in reversed(self.pending_deletes):
            self.logger.debug("Delete: %s", msg)
            self.c_delete += 1
            try:
                obj = self.model.objects.get(pk=self.mappings[r_id])
                obj.delete()
            except ValueError as e:  # Referred Error
                self.logger.error("%s", str(e))
                self.referred_errors += [(r_id, msg)]
            except KeyError as e:
                # Undefined mappings
                self.logger.error("%s", str(e))
            except self.model.DoesNotExist:
                pass  # Already deleted
        self.pending_deletes = []

    def save_state(self):
        """
        Save current state
        """
        if not self.new_state_path:
            return
        self.logger.info("Summary: %d new, %d changed, %d removed", self.c_add,
                         self.c_change, self.c_delete)
        self.logger.info("Error delete by referred: %s",
                         "\n".join(b.json() for _, b in self.referred_errors))
        t = time.localtime()
        archive_path = os.path.join(
            self.archive_dir,
            compressor.get_path("import-%04d-%02d-%02d-%02d-%02d-%02d.jsonl" %
                                tuple(t[:6])),
        )
        self.logger.info("Moving %s to %s", self.new_state_path, archive_path)
        if self.new_state_path.endswith(compressor.ext):
            # Simply move the file
            shutil.move(self.new_state_path, archive_path)
        else:
            # Compress the file
            self.logger.info("Compressing")
            with open(self.new_state_path,
                      "r") as s, compressor(archive_path, "w") as d:
                d.write(s.read())
            os.unlink(self.new_state_path)
        self.logger.info("Saving mappings to %s", self.mappings_path)
        mdata = "\n".join("%s,%s" % (k, self.mappings[k])
                          for k in sorted(self.mappings))
        safe_rewrite(self.mappings_path, mdata)

    def clean(self, item: BaseModel) -> Dict[str, Any]:
        """
        Cleanup row and return a dict of field name -> value
        """
        r = {
            k: self.clean_map.get(k, self.clean_any)(v)
            for k, v in item.dict().items()
        }
        # Fill integration fields
        r["remote_system"] = self.system.remote_system
        r["remote_id"] = self.clean_str(item.id)
        return r

    def clean_any(self, value: Any) -> Any:
        return value

    def clean_str(self, value) -> Optional[str]:
        if not value:
            return None
        if isinstance(value, str):
            return smart_text(value)
        return str(value)

    def clean_map_str(self, mappings, value):
        value = self.clean_str(value)
        if self.disable_mappings and not mappings:
            return value
        elif value:
            try:
                value = mappings[value]
            except KeyError:
                self.logger.warning("Deferred. Unknown map value: %s", value)
                raise self.Deferred
        return value

    def clean_bool(self, value: str) -> Optional[bool]:
        if value == "":
            return None
        try:
            return int(value) != 0
        except ValueError:
            pass
        value = value.lower()
        return value in ("t", "true", "y", "yes")

    def clean_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = mappings[value]
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def clean_int_reference(self, mappings, r_model, value):
        if not value:
            return None
        elif self.disable_mappings and not mappings:
            return value
        else:
            # @todo: Get proper mappings
            try:
                value = int(mappings[value])
            except KeyError:
                self.logger.info("Deferred. Unknown value %s:%s", r_model,
                                 value)
                raise self.Deferred()
            return self.chain.cache[r_model, value]

    def set_mappings(self, rv, lv):
        self.logger.debug("Set mapping remote: %s, local: %s", rv, lv)
        self.mappings[str(rv)] = str(lv)

    def update_document_clean_map(self):
        from mongoengine.fields import BooleanField, ReferenceField
        from noc.core.mongo.fields import PlainReferenceField, ForeignKeyField

        self.logger.debug("Update Document clean map")
        for fn, ft in self.model._fields.items():
            if fn not in self.data_model.__fields__:
                continue
            if isinstance(ft, BooleanField):
                self.clean_map[fn] = self.clean_bool
            elif isinstance(ft, (PlainReferenceField, ReferenceField)):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif isinstance(ft, ForeignKeyField):
                if fn in self.mapped_fields:
                    self.clean_map[fn] = functools.partial(
                        self.clean_int_reference,
                        self.chain.get_mappings(self.mapped_fields[fn]),
                        ft.document_type,
                    )
            elif fn in self.mapped_fields:
                self.clean_map[fn] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[fn]))

    def update_model_clean_map(self):
        from django.db.models import BooleanField, ForeignKey
        from noc.core.model.fields import DocumentReferenceField

        self.logger.debug("Update Model clean map")
        for f in self.model._meta.fields:
            if f.name not in self.data_model.__fields__:
                continue
            if isinstance(f, BooleanField):
                self.clean_map[f.name] = self.clean_bool
            elif isinstance(f, DocumentReferenceField):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.document,
                    )
            elif isinstance(f, ForeignKey):
                if f.name in self.mapped_fields:
                    self.clean_map[f.name] = functools.partial(
                        self.clean_reference,
                        self.chain.get_mappings(self.mapped_fields[f.name]),
                        f.remote_field.model,
                    )
            elif f.name in self.mapped_fields:
                self.clean_map[f.name] = functools.partial(
                    self.clean_map_str,
                    self.chain.get_mappings(self.mapped_fields[f.name]))

    def check(self, chain):
        self.logger.info("Checking")
        # Get constraints
        if self.is_document:
            # Document
            required_fields = [
                f.name for f in self.model._fields.values()
                if f.required or f.unique
            ]
            unique_fields = [
                f.name for f in self.model._fields.values() if f.unique
            ]
        else:
            # Model
            required_fields = [
                f.name for f in self.model._meta.fields if not f.blank
            ]
            unique_fields = [
                f.name for f in self.model._meta.fields
                if f.unique and f.name != self.model._meta.pk.name
            ]
        if not required_fields and not unique_fields:
            self.logger.info("Nothing to check, skipping")
            return 0
        self.logger.debug("[%s] Required fields: %s", self.model,
                          required_fields)
        self.logger.debug("[%s] Unique fields: %s", self.model, unique_fields)
        self.logger.debug("[%s] Mapped fields: %s", self.model,
                          self.mapped_fields)
        # Prepare data
        ns = self.get_new_state()
        if not ns:
            self.logger.info("No new state, skipping")
            return 0
        new_state = self.iter_jsonl(ns)
        uv = set()
        m_data = {}  # field_number -> set of mapped ids
        # Load mapped ids
        for f in self.mapped_fields:
            line = chain.get_loader(self.mapped_fields[f])
            ls = line.get_new_state()
            if not ls:
                ls = line.get_current_state()
            ms = self.iter_jsonl(ls, data_model=line.data_model)
            m_data[self.data_model.__fields__[f].name] = set(row.id
                                                             for row in ms)
        # Process data
        n_errors = 0
        for row in new_state:
            row = row.dict()
            lr = len(row)
            # Check required fields
            for f in required_fields:
                if f not in self.data_model.__fields__:
                    continue
                if f not in row:
                    self.logger.error(
                        "ERROR: Required field #(%s) is missed in row: %s",
                        f,
                        # self.fields[i],
                        row,
                    )
                    n_errors += 1
                    continue
            # Check unique fields
            for f in unique_fields:
                if f in self.ignore_unique:
                    continue
                v = row[f]
                if v in uv:
                    self.logger.error(
                        "ERROR: Field #(%s) value is not unique: %s",
                        f,
                        # self.fields[i],
                        row,
                    )
                    n_errors += 1
                else:
                    uv.add(v)
            # Check mapped fields
            for i, f in enumerate(self.mapped_fields):
                if i >= lr:
                    continue
                v = row[f]
                if v and v not in m_data[f]:
                    self.logger.error(
                        "ERROR: Field #%d(%s) == '%s' refers to non-existent record: %s",
                        i,
                        f,
                        row[f],
                        row,
                    )
                    n_errors += 1
        if n_errors:
            self.logger.info("%d errors found", n_errors)
        else:
            self.logger.info("No errors found")
        return n_errors

    def check_diff(self):
        def dump(cmd, row):
            print("%s %s" % (cmd, row.json()))

        print("--- %s.%s" % (self.chain.system.name, self.name))
        ns = self.get_new_state()
        if not ns:
            return
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                dump("+", n)
            elif o and n is None:
                dump("-", o)
            else:
                dump("/", o)
                dump("\\", n)

    def check_diff_summary(self):
        i, u, d = 0, 0, 0
        ns = self.get_new_state()
        if not ns:
            return i, u, d
        current_state = self.iter_jsonl(self.get_current_state())
        new_state = self.iter_jsonl(ns)
        for o, n in self.diff(current_state, new_state):
            if o is None and n:
                i += 1
            elif o and n is None:
                d += 1
            else:
                u += 1
        return i, u, d
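
A minimal sketch of pairing a data model with this loader (hypothetical names; `BaseModel` is the project's pydantic-based ETL model imported above, assumed to provide parse_raw/dict/json and get_mapped_fields as used by the loader, while `City` and `chain` are assumed as in the previous example):

class CityModel(BaseModel):
    id: str
    name: str
    region: Optional[str] = None

class CityLoader(BaseLoader):
    name = "city"
    model = City  # assumed database model
    data_model = CityModel

loader = CityLoader(chain)
loader.load()        # diffs the archived and the new import.jsonl and applies changes
loader.save_state()  # archives the imported file and persists mappings.csv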
Example #4
class BaseExtractor(object):
    """
    Data extractor interface. Subclasses must provide
    *iter_data* method
    """

    name = None
    PREFIX = config.path.etl_import
    REPORT_INTERVAL = 1000
    # Type of model
    model: Type[BaseModel]
    # List of rows to be used as constant data
    data: List[BaseModel] = []
    # Suppress deduplication message
    suppress_deduplication_log: bool = False

    def __init__(self, system):
        self.system = system
        self.config = system.config
        self.logger = PrefixLoggerAdapter(logger, "%s][%s" % (system.name, self.name))
        self.import_dir = os.path.join(self.PREFIX, system.name, self.name)
        self.fatal_problems: List[Problem] = []
        self.quality_problems: List[Problem] = []

    def register_quality_problem(
        self, line: int, p_class: str, message: str, row: List[Any]
    ) -> None:
        self.quality_problems += [
            Problem(line=line + 1, is_rej=False, p_class=p_class, message=message, row=row)
        ]

    def register_fatal_problem(self, line: int, p_class: str, message: str, row: List[Any]) -> None:
        self.fatal_problems += [
            Problem(line=line + 1, is_rej=True, p_class=p_class, message=message, row=row)
        ]

    def ensure_import_dir(self) -> None:
        """
        Ensure the import directory exists
        :return:
        """
        if os.path.isdir(self.import_dir):
            return
        self.logger.info("Creating directory %s", self.import_dir)
        os.makedirs(self.import_dir)

    def get_new_state(self) -> io.TextIOWrapper:
        self.ensure_import_dir()
        path = compressor.get_path(os.path.join(self.import_dir, "import.jsonl"))
        self.logger.info("Writing to %s", path)
        return compressor(path, "w").open()

    @contextlib.contextmanager
    def with_new_state(self):
        """
        New state context manager. Usage::

        with e.with_new_state() as f:
            ...

        :return:
        """
        f = self.get_new_state()
        try:
            yield f
        finally:
            f.close()

    def get_problem_file(self) -> io.TextIOWrapper:
        self.ensure_import_dir()
        path = compressor.get_path(os.path.join(self.import_dir, "import.csv.rej"))
        self.logger.info("Writing to %s", path)
        return compressor(path, "w").open()

    @contextlib.contextmanager
    def with_problem_file(self):
        """
        Problem file context manager. Usage::

        with e.with_problem_file() as f:
            ...

        :return:
        """
        f = self.get_problem_file()
        try:
            yield f
        finally:
            f.close()

    def iter_data(self) -> Iterable[Union[BaseModel, Tuple[Any, ...]]]:
        yield from self.data

    def filter(self, row):
        return True

    def clean(self, row):
        return row

    def extract(self):
        def q(s):
            if s == "" or s is None:
                return ""
            elif isinstance(s, str):
                return smart_text(s)
            else:
                return str(s)

        def get_model(raw) -> BaseModel:
            if isinstance(raw, BaseModel):
                return raw
            return self.model.from_iter(q(x) for x in raw)

        # Fetch data
        self.logger.info("Extracting %s from %s", self.name, self.system.name)
        t0 = perf_counter()
        data: List[BaseModel] = []
        n = 0
        seen = set()
        for row in self.iter_data():
            if not self.filter(row):
                continue
            row = self.clean(row)
            # Do not inline get_model(self.clean(row)): clean() may zip_longest a broken row before the model is built
            row = get_model(row)
            if row.id in seen:
                if not self.suppress_deduplication_log:
                    self.logger.error("Duplicated row truncated: %r", row)
                continue
            seen.add(row.id)
            data += [row]
            n += 1
            if n % self.REPORT_INTERVAL == 0:
                self.logger.info("   ... %d records", n)
        dt = perf_counter() - t0
        speed = n / dt
        self.logger.info("%d records extracted in %.2fs (%d records/s)", n, dt, speed)
        # Write
        with self.with_new_state() as f:
            for n, item in enumerate(sorted(data, key=operator.attrgetter("id"))):
                if n:
                    f.write("\n")
                f.write(item.json(exclude_defaults=True, exclude_unset=True))
        if self.fatal_problems or self.quality_problems:
            self.logger.warning(
                "Detect problems on extracting, fatal: %d, quality: %d",
                len(self.fatal_problems),
                len(self.quality_problems),
            )
            self.logger.warning("Line num\tType\tProblem string")
            for p in self.fatal_problems:
                self.logger.warning(
                    "Fatal problem, line was rejected: %s\t%s\t%s" % (p.line, p.p_class, p.message)
                )
            for p in self.quality_problems:
                self.logger.warning(
                    "Data quality problem in line:  %s\t%s\t%s" % (p.line, p.p_class, p.message)
                )
            # Dump problem to file
            try:
                with self.with_problem_file() as f:
                    writer = csv.writer(f, delimiter=";")
                    for p in itertools.chain(self.quality_problems, self.fatal_problems):
                        writer.writerow(
                            [smart_text(c) for c in p.row]
                            + [
                                "Fatal problem, line was rejected"
                                if p.is_rej
                                else "Data quality problem"
                            ]
                            + [p.message]
                        )
            except IOError as e:
                self.logger.error("Error when saved problems %s", e)
        else:
            self.logger.info("No problems detected")
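
A minimal, hypothetical subclass sketch for the extractor above. It assumes the models are pydantic v1 style (suggested by the .json()/.dict()/__fields__ usage) and carry an "id" field, since extract() deduplicates and sorts on row.id; every name below is illustrative:

# Illustrative only: a tiny model plus an extractor that serves constant data.
from pydantic import BaseModel  # assumption: the models above are pydantic v1

class VendorData(BaseModel):
    id: str
    name: str

class VendorExtractor(BaseExtractor):
    name = "vendor"
    model = VendorData
    # Constant rows; the default iter_data() simply yields from this list
    data = [
        VendorData(id="1", name="Cisco"),
        VendorData(id="2", name="Juniper"),
    ]

# extractor = VendorExtractor(system)  # "system" must expose .name and .config
# extractor.extract()                  # writes sorted JSONL via with_new_state()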
Example #5
class ProfileChecker(object):
    base_logger = logging.getLogger("profilechecker")
    _rules_cache = cachetools.TTLCache(10, ttl=60)
    _re_cache = {}

    def __init__(
        self,
        address=None,
        pool=None,
        logger=None,
        snmp_community=None,
        calling_service="profilechecker",
        snmp_version=None,
    ):
        self.address = address
        self.pool = pool
        self.logger = PrefixLoggerAdapter(
            logger or self.base_logger, "%s][%s" % (self.pool or "", self.address or "")
        )
        self.result_cache = {}  # (method, param) -> result
        self.error = None
        self.snmp_community = snmp_community
        self.calling_service = calling_service
        self.snmp_version = snmp_version or [SNMP_v2c]
        self.ignoring_snmp = False
        if self.snmp_version is None:
            self.logger.error("SNMP is not supported. Ignoring")
            self.ignoring_snmp = True
        if not self.snmp_community:
            self.logger.error("No SNMP credentials. Ignoring")
            self.ignoring_snmp = True

    def find_profile(self, method, param, result):
        """
        Find profile by method
        :param method: Fingerprint getting method
        :param param: Method params
        :param result: Getting params result
        :return:
        """
        r = defaultdict(list)
        d = self.get_rules()
        for k, value in sorted(six.iteritems(d), key=lambda x: x[0]):
            for v in value:
                r[v] += value[v]
        if (method, param) not in r:
            self.logger.warning("No rule found for method: %s %s", method, param)
            return
        for match_method, value, action, profile, rname in r[(method, param)]:
            if self.is_match(result, match_method, value):
                self.logger.info("Matched profile: %s (%s)", profile, rname)
                # @todo: process MAYBE rule
                return profile

    def get_profile(self):
        """
        Returns profile for object, or None when not known
        """
        snmp_result = ""
        http_result = ""
        for ruleset in self.iter_rules():
            for (method, param), actions in ruleset:
                try:
                    result = self.do_check(method, param)
                    if not result:
                        continue
                    if "snmp" in method:
                        snmp_result = result
                    if "http" in method:
                        http_result = result
                    for match_method, value, action, profile, rname in actions:
                        if self.is_match(result, match_method, value):
                            self.logger.info("Matched profile: %s (%s)", profile, rname)
                            # @todo: process MAYBE rule
                            return profile
                except NOCError as e:
                    self.logger.error(e.message)
                    self.error = str(e.message)
                    return None
        if snmp_result or http_result:
            self.error = "Cannot find profile for OID: %s or HTTP string: %s" % (
                snmp_result,
                http_result,
            )
        elif not snmp_result:
            self.error = "Cannot fetch snmp data, check device for SNMP access"
        elif not http_result:
            self.error = "Cannot fetch HTTP data, check device for HTTP access"
        self.logger.info("Cannot detect profile: %s", self.error)
        return None

    def get_error(self):
        """
        Get error message
        :return:
        """
        return self.error

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_rules_cache"), lock=lambda _: rules_lock)
    def get_profile_check_rules(cls):
        return list(ProfileCheckRule.objects.all().order_by("preference"))

    def get_rules(self):
        """
        Load ProfileCheckRules and return a list, grouped by preferences
        [{
            (method, param) -> [(
                    match_method,
                    value,
                    action,
                    profile,
                    rule_name
                ), ...]

        }]
        """
        self.logger.info('Compiling "Profile Check rules"')
        d = {}  # preference -> (method, param) -> [rule, ..]
        for r in self.get_profile_check_rules():
            if "snmp" in r.method and self.ignoring_snmp:
                continue
            if r.preference not in d:
                d[r.preference] = {}
            k = (r.method, r.param)
            if k not in d[r.preference]:
                d[r.preference][k] = []
            d[r.preference][k] += [(r.match_method, r.value, r.action, r.profile, r.name)]
        return d

    def iter_rules(self):
        d = self.get_rules()
        for p in sorted(d):
            yield list(six.iteritems(d[p]))

    @classmethod
    @cachetools.cachedmethod(operator.attrgetter("_re_cache"))
    def get_re(cls, regexp):
        return re.compile(regexp)

    def do_check(self, method, param):
        """
        Perform check
        """
        self.logger.debug("do_check(%s, %s)", method, param)
        if (method, param) in self.result_cache:
            self.logger.debug("Using cached value")
            return self.result_cache[method, param]
        h = getattr(self, "check_%s" % method, None)
        if not h:
            self.logger.error("Invalid check method '%s'. Ignoring", method)
            return None
        result = h(param)
        self.result_cache[method, param] = result
        return result

    def check_snmp_v2c_get(self, param):
        """
        Perform SNMP v2c GET. Param is OID or symbolic name
        """
        try:
            param = mib[param]
        except KeyError:
            self.logger.error("Cannot resolve OID '%s'. Ignoring", param)
            return None
        for v in self.snmp_version:
            if v == SNMP_v1:
                r = self.snmp_v1_get(param)
            elif v == SNMP_v2c:
                r = self.snmp_v2c_get(param)
            else:
                raise NOCError(msg="Unsupported SNMP version")
            if r:
                return r

    def check_http_get(self, param):
        """
        Perform HTTP GET check. Param can be URL path or :<port>/<path>
        """
        url = "http://%s%s" % (self.address, param)
        return self.http_get(url)

    def check_https_get(self, param):
        """
        Perform HTTPS GET check. Param can be URL path or :<port>/<path>
        """
        url = "https://%s%s" % (self.address, param)
        return self.https_get(url)

    def is_match(self, result, method, value):
        """
        Returns True when result matches value
        """
        if method == "eq":
            return result == value
        elif method == "contains":
            return value in result
        elif method == "re":
            return bool(self.get_re(value).search(result))
        else:
            self.logger.error("Invalid match method '%s'. Ignoring", method)
            return False
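    # Illustration (hypothetical fingerprint values): is_match() compares the
    # fetched result against a rule value using one of three methods:
    #   is_match("1.3.6.1.4.1.9.1.620", "eq", "1.3.6.1.4.1.9.1.620")  -> True
    #   is_match("Cisco IOS Software, C2960", "contains", "IOS")      -> True
    #   is_match("JUNOS 18.4R1", "re", r"^JUNOS \d+")                 -> True
    # Any other match method is logged as an error and treated as no match.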

    def snmp_v1_get(self, param):
        """
        Perform SNMP v1 request. May be overridden for testing
        :param param:
        :return:
        """
        self.logger.info("SNMP v1 GET: %s", param)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).snmp_v1_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def snmp_v2c_get(self, param):
        """
        Perform SNMP v2c request. May be overridden for testing
        :param param:
        :return:
        """
        self.logger.info("SNMP v2c GET: %s", param)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).snmp_v2c_get(self.address, self.snmp_community, param)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def http_get(self, url):
        """
        Perform HTTP request. May be overridden for testing
        :param url: Request URL
        :return:
        """
        self.logger.info("HTTP Request: %s", url)
        try:
            return open_sync_rpc(
                "activator", pool=self.pool, calling_service=self.calling_service
            ).http_get(url, True)
        except RPCError as e:
            self.logger.error("RPC Error: %s", e)
            return None

    def https_get(self, url):
        """
        Perform HTTPS request. May be overridden for testing
        :param url: Request URL
        :return:
        """
        return self.http_get(url)
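
A hedged end-to-end usage sketch for the checker above; the address, pool name and community string are illustrative, and SNMP_v2c is assumed to be the version constant imported by this module:

# Hypothetical usage: fingerprint a device and report the detected profile.
checker = ProfileChecker(
    address="192.0.2.10",
    pool="default",
    snmp_community="public",
    snmp_version=[SNMP_v2c],
)
profile = checker.get_profile()
if profile:
    print("Detected profile: %s" % profile)
else:
    print("Detection failed: %s" % checker.get_error())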