示例#1
0
class TestSimpleLogger(unittest.TestCase):
    """Check that SimpleLogger output lands on stdout/stderr, using in-memory streams."""
    def setUp(self):
        self.logger = SimpleLogger()
        # Keep the real streams so tearDown can restore them after each test.
        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = StringIO(), StringIO()

    @staticmethod
    def _read_back(stream):
        """Rewind an in-memory stream and return its stripped contents."""
        stream.seek(0)
        return stream.read().strip()

    def test_stdout(self):
        """log() output should be readable back from stdout."""
        text = "test message to stdout"
        self.logger.log(text)
        self.assertEqual(text, self._read_back(sys.stdout))

    def test_stderr(self):
        """errlog() output should be readable back from stderr."""
        text = "test message to stderr"
        self.logger.errlog(text)
        self.assertEqual(text, self._read_back(sys.stderr))

    def tearDown(self):
        sys.stdout, sys.stderr = self.old_stdout, self.old_stderr
示例#2
0
class FileDownloader:
    """
    Main downloader class. Holds a queue of items to download. When run() is called, a number of
    download threads (at most THREADS, and no more than the number of queued items) is started,
    and the downloader waits until all of those threads have finished.
    """
    def __init__(self):
        self.queue = Queue()
        self.logger = SimpleLogger()

    def add(self, download_item):
        """Add DownloadItem object into the queue."""
        self.queue.put(download_item)

    def run(self):
        """Start processing download queue using multiple threads and block until they finish."""
        self.logger.log("Downloading started.")
        threads = []
        for i in range(min(THREADS, self.queue.qsize())):
            self.logger.log("Starting thread %d." % i)
            thread = FileDownloadThread(self.queue, self.logger)
            # Daemon threads do not keep the interpreter alive; the `daemon` attribute replaces
            # Thread.setDaemon(), which is deprecated since Python 3.10.
            thread.daemon = True
            thread.start()
            threads.append(thread)

        for i, thread in enumerate(threads):
            thread.join()
            self.logger.log("Stopping thread %d." % i)
        self.logger.log("Downloading finished.")
        self.logger.log("Downloading finished.")
示例#3
0
class CveRepoStore:
    """
    Persist per-CVE-list metadata (such as the lastmodified timestamp) in the
    metadata table and delegate storing the CVEs themselves to CveStore.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.repo = []
        self.conn = DatabaseHandler.get_connection()
        self.cve_store = CveStore()

    def list_lastmodified(self):
        """
        Return {label: lastmodified} for every CVE list synced in the past.

        Labels are the metadata keys with their "nistcve:" prefix removed.
        """
        prefix = "nistcve:"
        cursor = self.conn.cursor()
        cursor.execute("select key, value from metadata where key like 'nistcve:%'")
        result = {key[len(prefix):]: value for key, value in cursor.fetchall()}
        cursor.close()
        return result

    def _import_repo(self, label, lastmodified):
        """Create or refresh the metadata row of CVE list *label* and return its id."""
        metadata_key = 'nistcve:' + label
        cursor = self.conn.cursor()
        cursor.execute("select id from metadata where key = %s", (metadata_key,))
        row = cursor.fetchone()
        if row:
            # Update repository timestamp
            cursor.execute("update metadata set value = %s where id = %s", (lastmodified, row[0],))
        else:
            cursor.execute("insert into metadata (key, value) values (%s, %s) returning id",
                           (metadata_key, lastmodified))
            row = cursor.fetchone()
        cursor.close()
        self.conn.commit()
        return row[0]

    def store(self, repo):
        """
        Record the CVE list's lastmodified time, then store all of its CVEs.
        """
        self.logger.log("Syncing CVE list: %s" % repo.label)
        self._import_repo(repo.label, repo.meta.get_lastmodified())
        self.logger.log("Syncing CVEs : %s" % repo.get_count())
        self.cve_store.store(repo)
示例#4
0
class FileUnpacker:
    """
    Class unpacking queued files.
    Files to unpack are collected and then all unpacked at once into their locations.
    Gz, Xz, Bz2 formats are supported.
    """
    def __init__(self):
        self.queue = []
        self.logger = SimpleLogger()

    def add(self, file_path):
        """Add compressed file path to queue."""
        self.queue.append(file_path)

    @staticmethod
    def _get_unpack_func(file_path):
        """Return the open() function matching the file's compression suffix, or None if unsupported."""
        if file_path.endswith(".gz"):
            return gzip.open
        elif file_path.endswith(".xz"):
            return lzma.open
        elif file_path.endswith(".bz2"):
            return bz2.open
        return None

    def _unpack(self, file_path):
        """
        Decompress file_path next to itself (dropping the last extension) and delete the
        compressed original. Files with an unsupported suffix are left untouched and logged.
        """
        unpack_func = self._get_unpack_func(file_path)
        if unpack_func:
            with unpack_func(file_path, "rb") as packed:
                unpacked_file_path = file_path.rsplit(".", maxsplit=1)[0]
                with open(unpacked_file_path, "wb") as unpacked:
                    # Copy in CHUNK_SIZE pieces so large files are never held in memory at once.
                    while True:
                        chunk = packed.read(CHUNK_SIZE)
                        if chunk == b"":
                            break
                        unpacked.write(chunk)
            os.unlink(file_path)
            self.logger.log("%s -> %s" % (file_path, unpacked_file_path))
        else:
            # Interpolate the path; previously the literal "%s skipped." was logged.
            self.logger.log("%s skipped." % file_path)

    def run(self):
        """Unpack all queued file paths."""
        self.logger.log("Unpacking started.")
        for file_path in self.queue:
            self._unpack(file_path)
        # Make queue empty to be able to reuse this class multiple times in one run
        self.queue = []
        self.logger.log("Unpacking finished.")
示例#5
0
class FileUnpacker:
    """
    Class unpacking queued files.
    Files to unpack are collected and then all unpacked at once into their locations.
    Gz, Xz, Bz2 formats are supported.
    """
    def __init__(self):
        self.queue = []
        self.logger = SimpleLogger()

    def add(self, file_path):
        """Add compressed file path to queue."""
        self.queue.append(file_path)

    @staticmethod
    def _get_unpack_func(file_path):
        """Return the open() function matching the file's compression suffix, or None if unsupported."""
        if file_path.endswith(".gz"):
            return gzip.open
        elif file_path.endswith(".xz"):
            return lzma.open
        elif file_path.endswith(".bz2"):
            return bz2.open
        else:
            return None

    def _unpack(self, file_path):
        """
        Decompress file_path next to itself (dropping the last extension) and delete the
        compressed original. Files with an unsupported suffix are left untouched and logged.
        """
        unpack_func = self._get_unpack_func(file_path)
        if unpack_func:
            with unpack_func(file_path, "rb") as packed:
                unpacked_file_path = file_path.rsplit(".", maxsplit=1)[0]
                with open(unpacked_file_path, "wb") as unpacked:
                    # Copy in CHUNK_SIZE pieces so large files are never held in memory at once.
                    while True:
                        chunk = packed.read(CHUNK_SIZE)
                        if chunk == b"":
                            break
                        unpacked.write(chunk)
            os.unlink(file_path)
            self.logger.log("%s -> %s" % (file_path, unpacked_file_path))
        else:
            # Interpolate the path; previously the literal "%s skipped." was logged.
            self.logger.log("%s skipped." % file_path)

    def run(self):
        """Unpack all queued file paths and empty the queue."""
        self.logger.log("Unpacking started.")
        for file_path in self.queue:
            self._unpack(file_path)
        # Empty the queue so a later run() does not retry files that were already unpacked
        # (and unlinked), which lets one instance be reused multiple times.
        self.queue = []
        self.logger.log("Unpacking finished.")
示例#6
0
class FileDownloader:
    """
    Downloader holding a queue of items. run() starts up to THREADS download threads
    (never more than the number of queued items) and waits until all of them finish.
    """
    def __init__(self):
        self.queue = Queue()
        self.logger = SimpleLogger()

    def add(self, download_item):
        """Add DownloadItem object into the queue."""
        self.queue.put(download_item)

    def run(self):
        """Process the download queue using multiple threads and block until they finish."""
        self.logger.log("Downloading started.")
        threads = []
        for i in range(min(THREADS, self.queue.qsize())):
            self.logger.log("Starting thread %d." % i)
            thread = FileDownloadThread(self.queue, self.logger)
            # Daemon threads do not keep the interpreter alive; the `daemon` attribute replaces
            # Thread.setDaemon(), which is deprecated since Python 3.10.
            thread.daemon = True
            thread.start()
            threads.append(thread)

        for i, t in enumerate(threads):
            t.join()
            self.logger.log("Stopping thread %d." % i)
        self.logger.log("Downloading finished.")
示例#7
0
class RepositoryController:
    """
    Class for importing/syncing set of repositories into the DB.
    First, repomd from all repositories are downloaded and parsed.
    Second, primary and updateinfo repodata from repositories needing update are downloaded, parsed and imported.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.downloader = FileDownloader()
        self.unpacker = FileUnpacker()
        self.repo_store = RepositoryStore()
        self.repositories = set()
        # Temporary directory holding certificate/key files, created by _write_certificate_cache().
        self.certs_tmp_directory = None
        # cert_name -> {"ca_cert": path, "cert": path, "key": path}; a path is None when not configured.
        self.certs_files = {}

    def _get_certs_tuple(self, name):
        """Return (ca_cert, cert, key) file paths written for certificate *name*, or three Nones."""
        if name in self.certs_files:
            files = self.certs_files[name]
            return files["ca_cert"], files["cert"], files["key"]
        return None, None, None

    def _download_repomds(self):
        """
        Download repomd.xml of every queued repository into a fresh temporary directory.

        :return: dict mapping the target path of each failed download to its status code
                 (a status code not in VALID_HTTP_CODES counts as a failure).
        """
        download_items = []
        for repository in self.repositories:
            repomd_url = urljoin(repository.repo_url, REPOMD_PATH)
            repository.tmp_directory = tempfile.mkdtemp(prefix="repo-")
            ca_cert, cert, key = self._get_certs_tuple(repository.cert_name)
            item = DownloadItem(source_url=repomd_url,
                                target_path=os.path.join(
                                    repository.tmp_directory, "repomd.xml"),
                                ca_cert=ca_cert,
                                cert=cert,
                                key=key)
            # Save for future status code check
            download_items.append(item)
            self.downloader.add(item)
        self.downloader.run()
        # Return failed downloads
        return {
            item.target_path: item.status_code
            for item in download_items
            if item.status_code not in VALID_HTTP_CODES
        }

    def _read_repomds(self, failed):
        """Reads all downloaded repomd files. Checks if their download failed and checks if their metadata are
           newer than metadata currently in DB. Sets repository.repomd only for repositories that need syncing.
        """
        # Fetch current list of repositories from DB
        db_repositories = self.repo_store.list_repositories()
        for repository in self.repositories:
            repomd_path = os.path.join(repository.tmp_directory, "repomd.xml")
            if repomd_path not in failed:
                repomd = RepoMD(repomd_path)
                # Was repository already synced before?
                if repository.repo_label in db_repositories:
                    db_revision = db_repositories[
                        repository.repo_label]["revision"]
                else:
                    db_revision = None
                downloaded_revision = datetime.fromtimestamp(
                    repomd.get_revision(), tz=timezone.utc)
                # Repository is synced for the first time or has newer revision
                if db_revision is None or downloaded_revision > db_revision:
                    repository.repomd = repomd
                else:
                    self.logger.log(
                        "Downloaded repo %s (%s) is not newer than repo in DB (%s)."
                        % (repository.repo_label, str(downloaded_revision),
                           str(db_revision)))
            else:
                # %s rather than %d: any status outside VALID_HTTP_CODES lands here, and it is not
                # guaranteed to be an int (NOTE(review): confirm what DownloadItem stores on
                # connection errors). For integer codes the output is unchanged.
                self.logger.log("Download failed: %s (HTTP CODE %s)" %
                                (urljoin(repository.repo_url,
                                         REPOMD_PATH), failed[repomd_path]))

    def _download_metadata(self, batch):
        """Queue and download primary(_db) and, if present, updateinfo metadata of every repository in batch."""
        for repository in batch:
            # primary_db has higher priority, use primary.xml if not found
            try:
                repository.md_files[
                    "primary_db"] = repository.repomd.get_metadata(
                        "primary_db")["location"]
            except RepoMDTypeNotFound:
                repository.md_files[
                    "primary"] = repository.repomd.get_metadata(
                        "primary")["location"]
            # updateinfo.xml may be missing completely
            try:
                repository.md_files[
                    "updateinfo"] = repository.repomd.get_metadata(
                        "updateinfo")["location"]
            except RepoMDTypeNotFound:
                pass

            # The certificate set is per repository, so look it up once, not once per file.
            ca_cert, cert, key = self._get_certs_tuple(repository.cert_name)
            # queue metadata files for download
            for md_location in repository.md_files.values():
                self.downloader.add(
                    DownloadItem(source_url=urljoin(repository.repo_url,
                                                    md_location),
                                 target_path=os.path.join(
                                     repository.tmp_directory,
                                     os.path.basename(md_location)),
                                 ca_cert=ca_cert,
                                 cert=cert,
                                 key=key))
        self.downloader.run()

    def _unpack_metadata(self, batch):
        """Unpack downloaded metadata files and point repository.md_files at the unpacked paths."""
        for repository in batch:
            for md_type in repository.md_files:
                packed_path = os.path.join(
                    repository.tmp_directory,
                    os.path.basename(repository.md_files[md_type]))
                self.unpacker.add(packed_path)
                # FIXME: this should be done in different place?
                # The unpacker writes the file without its last extension.
                repository.md_files[md_type] = packed_path.rsplit(
                    ".", maxsplit=1)[0]
        self.unpacker.run()

    def clean_repodata(self, batch):
        """Clean downloaded repodata of all repositories in batch."""
        for repository in batch:
            if repository.tmp_directory:
                shutil.rmtree(repository.tmp_directory)
                repository.tmp_directory = None
            self.repositories.remove(repository)

    def _clean_certificate_cache(self):
        """Delete the temporary certificate/key directory, if one was created."""
        if self.certs_tmp_directory:
            shutil.rmtree(self.certs_tmp_directory)
            self.certs_tmp_directory = None
            self.certs_files = {}

    def add_synced_repositories(self):
        """Queue all previously synced repositories."""
        repos = self.repo_store.list_repositories()
        for repo_label, repo_dict in repos.items():
            # Reference content_set_label -> content set id
            self.repo_store.content_set_to_db_id[
                repo_dict["content_set"]] = repo_dict["content_set_id"]
            self.repositories.add(
                Repository(repo_label,
                           repo_dict["url"],
                           content_set=repo_dict["content_set"],
                           cert_name=repo_dict["cert_name"],
                           ca_cert=repo_dict["ca_cert"],
                           cert=repo_dict["cert"],
                           key=repo_dict["key"]))

    # pylint: disable=too-many-arguments
    def add_repository(self,
                       repo_label,
                       repo_url,
                       content_set=None,
                       cert_name=None,
                       ca_cert=None,
                       cert=None,
                       key=None):
        """Queue repository to import/check updates."""
        repo_url = repo_url.strip()
        # urljoin() replaces the last path segment unless the base ends with "/".
        if not repo_url.endswith("/"):
            repo_url += "/"
        self.repositories.add(
            Repository(repo_label,
                       repo_url,
                       content_set=content_set,
                       cert_name=cert_name,
                       ca_cert=ca_cert,
                       cert=cert,
                       key=key))

    def _write_certificate_cache(self):
        """Write certificates/keys of all queued repositories to files in a temporary directory."""
        certs = {}
        for repository in self.repositories:
            if repository.cert_name:
                certs[repository.cert_name] = {
                    "ca_cert": repository.ca_cert,
                    "cert": repository.cert,
                    "key": repository.key
                }
        if certs:
            # mkdtemp() creates the directory accessible only by the creating user.
            self.certs_tmp_directory = tempfile.mkdtemp(prefix="certs-")
            for cert_name in certs:
                self.certs_files[cert_name] = {}
                for cert_type in ["ca_cert", "cert", "key"]:
                    # Cert is not None
                    if certs[cert_name][cert_type]:
                        cert_path = os.path.join(
                            self.certs_tmp_directory,
                            "%s.%s" % (cert_name, cert_type))
                        with open(cert_path, "w") as cert_file:
                            cert_file.write(certs[cert_name][cert_type])
                        self.certs_files[cert_name][cert_type] = cert_path
                    else:
                        self.certs_files[cert_name][cert_type] = None

    def store(self):
        """Sync all queued repositories. Process repositories in batches due to disk space and memory usage."""
        self.logger.log("Checking %d repositories." % len(self.repositories))

        try:
            self._write_certificate_cache()

            # Download all repomd files first
            failed = self._download_repomds()
            self.logger.log("%d repomd.xml files failed to download." %
                            len(failed))
            self._read_repomds(failed)

            # Filter all repositories without repomd attribute set (failed download, downloaded repomd is not newer)
            batches = BatchList()
            to_skip = []
            for repository in self.repositories:
                if repository.repomd:
                    batches.add_item(repository)
                else:
                    to_skip.append(repository)
            self.clean_repodata(to_skip)
            self.logger.log("%d repositories skipped." % len(to_skip))
            self.logger.log("Syncing %d repositories." %
                            sum(len(l) for l in batches))

            # Download and process repositories in batches (unpacked metadata files can consume lot of disk space)
            for batch in batches:
                self._download_metadata(batch)
                self._unpack_metadata(batch)
                for repository in batch:
                    repository.load_metadata()
                    self.repo_store.store(repository)
                    repository.unload_metadata()
                self.clean_repodata(batch)
        finally:
            # Always remove certificate/key files from disk, even when syncing fails part-way.
            self._clean_certificate_cache()
示例#8
0
class ProductStore: # pylint: disable=too-few-public-methods
    """
    Class providing interface for storing product info.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.conn = DatabaseHandler.get_connection()
        # Access this dictionary from repository_store to reference content set table.
        self.cs_to_dbid = {}

    def _import_products(self, products):
        """
        Make sure every product in *products* exists in the product table.

        :param products: dict keyed by Red Hat engineering product id; each value has a "name" entry.
        :return: dict mapping engineering product id -> product table id.
        """
        engid_to_dbid = {}
        self.logger.log("Syncing %d products." % len(products))
        cur = self.conn.cursor()
        # An empty tuple would render as "in ()", which is invalid SQL, so only look up when non-empty.
        if products:
            cur.execute("select id, redhat_eng_product_id from product where redhat_eng_product_id in %s",
                        (tuple(products.keys()),))
            for row in cur.fetchall():
                engid_to_dbid[row[1]] = row[0]
        missing_products = [(product, products[product]["name"])
                            for product in products if product not in engid_to_dbid]
        self.logger.log("Products already in DB: %d" % len(engid_to_dbid))
        self.logger.log("Products to import: %d" % len(missing_products))
        if missing_products:
            execute_values(cur, """insert into product (redhat_eng_product_id, name) values %s
                                   returning id, redhat_eng_product_id""", missing_products,
                           page_size=len(missing_products))
            for row in cur.fetchall():
                engid_to_dbid[row[1]] = row[0]
        cur.close()
        self.conn.commit()
        return engid_to_dbid

    def _import_content_sets(self, products):
        """
        Import products, then make sure every content set they list exists in content_set.

        Fills self.cs_to_dbid with content set label -> content_set table id.
        Each product's "content_sets" entry maps content set label -> name.
        """
        engid_to_dbid = self._import_products(products)
        all_content_set_labels = [cs for product in products.values() for cs in product["content_sets"]]
        self.logger.log("Syncing %d content sets." % len(all_content_set_labels))
        cur = self.conn.cursor()
        # Same guard as for products: "in ()" is not valid SQL.
        if all_content_set_labels:
            cur.execute("select id, label from content_set where label in %s", (tuple(all_content_set_labels),))
            for row in cur.fetchall():
                self.cs_to_dbid[row[1]] = row[0]
        missing_content_sets = []
        for product in products:
            for content_set in products[product]["content_sets"]:
                if content_set not in self.cs_to_dbid:
                    # label, name, product_id
                    missing_content_sets.append((content_set, products[product]["content_sets"][content_set],
                                                 engid_to_dbid[product]))
        self.logger.log("Content sets already in DB: %d" % len(self.cs_to_dbid))
        self.logger.log("Content sets to import: %d" % len(missing_content_sets))
        if missing_content_sets:
            execute_values(cur, """insert into content_set (label, name, product_id) values %s
                                   returning id, label""", missing_content_sets, page_size=len(missing_content_sets))
            for row in cur.fetchall():
                self.cs_to_dbid[row[1]] = row[0]
        cur.close()
        self.conn.commit()

    def store(self, products):
        """
        Import all product info from input dictionary.
        """
        self._import_content_sets(products)
示例#9
0
class CveStore:
    """
    Class interface for listing and storing CVEs in database.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.conn = DatabaseHandler.get_connection()

    def list_lastmodified(self):
        """
        List lastmodified times from database.

        Returns a dict mapping each "nistcve:<label>" metadata key's label to its stored value.
        """
        lastmodified = {}
        cur = self.conn.cursor()
        # The trailing % wildcard is required: without it LIKE only matches the bare
        # "nistcve:" key instead of every "nistcve:<label>" key.
        cur.execute(
            "select key, value from metadata where key like 'nistcve:%'")
        for row in cur.fetchall():
            label = row[0][8:]  # strip nistcve: prefix
            lastmodified[label] = row[1]
        cur.close()
        return lastmodified

    def _populate_severities(self, repo):
        """
        Ensure every CVSSv3 base severity used by CVEs in *repo* exists in the severity table.

        :return: dict mapping capitalized severity name -> severity table id.
        """
        severities = {}
        cur = self.conn.cursor()
        cur.execute("select id, name from severity")
        for row in cur.fetchall():
            severities[row[1]] = row[0]
        missing_severities = set()
        for cve in repo.list_cves():
            severity = _dget(cve, "impact", "baseMetricV3", "cvssV3",
                             "baseSeverity")
            if severity is not None:
                severity = severity.capitalize()
                if severity not in severities:
                    missing_severities.add((severity, ))
        self.logger.log("Severities missing in DB: %d" %
                        len(missing_severities))
        if missing_severities:
            execute_values(
                cur,
                "insert into severity (name) values %s returning id, name",
                missing_severities,
                page_size=len(missing_severities))
            for row in cur.fetchall():
                severities[row[1]] = row[0]
        cur.close()
        self.conn.commit()
        return severities

    def _populate_cwes(self, cursor, cve_data):
        """
        Insert CWEs referenced by *cve_data* that are missing from the cwe table, then insert
        missing cve_cwe mappings. Uses the caller's cursor and does not commit.
        """
        # pylint: disable=R0914
        # Populate CWE table
        cursor.execute("SELECT name, id FROM cwe")
        cwe_name_id = cursor.fetchall()
        cwe_name = [cwe[0] for cwe in cwe_name_id]
        cwe_list = []
        cwe_link_map = dict()
        for cve in cve_data.values():
            for cwe in cve["cwe_list"]:
                cwe_list.append(cwe['cwe_name'])
                cwe_link_map[cwe['cwe_name']] = cwe['link']

        import_set = set(cwe_list) - set(cwe_name)
        import_set = [(name, cwe_link_map[name]) for name in import_set]
        self.logger.log("CWEs to import: %d" % len(import_set))
        new_cwes = ()
        if import_set:
            execute_values(
                cursor,
                "INSERT INTO cwe (name, link) values %s returning name, id",
                import_set,
                page_size=len(import_set))
            new_cwes = cursor.fetchall()

        # Populate cve_cwe mappings
        mapping_set = []
        for entry in cve_data.values():
            cwe_names = [x["cwe_name"] for x in entry["cwe_list"]]
            mapping_set.extend([(entry['id'], cwe_name)
                                for cwe_name in cwe_names])

        # Some entries are not commited to DB yet, get them from last insert
        all_cwes = dict(tuple(new_cwes) + tuple(cwe_name_id))
        # Lookup IDs for missing cwes
        cve_cwe_map = _map_name_to_id(set(mapping_set), all_cwes)
        cursor.execute("SELECT cve_id, cwe_id FROM cve_cwe")
        to_import = set(cve_cwe_map) - set(cursor.fetchall())
        self.logger.log("CVE to CWE mapping to import: %d" % len(to_import))
        if to_import:
            execute_values(
                cursor,
                "INSERT INTO cve_cwe (cve_id, cwe_id) values %s returning cve_id, cwe_id",
                list(to_import),
                page_size=len(to_import))

    def _populate_cves(self, repo):  # pylint: disable=too-many-locals
        """
        Insert new CVEs from *repo*, update the ones already in the cve table, then sync CWEs.
        Commits at the end.

        :return: dict mapping CVE name -> collected field values (including its DB "id").
        """
        severity_map = self._populate_severities(repo)
        cur = self.conn.cursor()
        cve_data = {}
        for cve in repo.list_cves():
            cve_name = _dget(cve, "cve", "CVE_data_meta", "ID")

            cve_desc_list = _dget(cve, "cve", "description",
                                  "description_data")
            severity = _dget(cve, "impact", "baseMetricV3", "cvssV3",
                             "baseSeverity")
            url_list = _dget(cve, "cve", "references", "reference_data")
            modified_date = datetime.strptime(_dget(cve, "lastModifiedDate"),
                                              "%Y-%m-%dT%H:%MZ")
            published_date = datetime.strptime(_dget(cve, "publishedDate"),
                                               "%Y-%m-%dT%H:%MZ")
            cwe_data = _dget(cve, "cve", "problemtype", "problemtype_data")
            cwe_list = _process_cwe_list(cwe_data)
            redhat_url, secondary_url = _process_url_list(cve_name, url_list)
            cve_data[cve_name] = {
                "description":
                _desc(cve_desc_list, "lang", "en", "value"),
                "severity_id":
                severity_map[severity.capitalize()]
                if severity is not None else None,
                "cvss3_score":
                _dget(cve, "impact", "baseMetricV3", "cvssV3", "baseScore"),
                "redhat_url":
                redhat_url,
                "cwe_list":
                cwe_list,
                "secondary_url":
                secondary_url,
                "published_date":
                published_date,
                "modified_date":
                modified_date,
                "iava":
                None,
            }

        if cve_data:
            names = [(key, ) for key in cve_data]
            execute_values(cur,
                           """select id, name from cve
                              inner join (values %s) t(name)
                              using (name)
                           """,
                           names,
                           page_size=len(names))
            for row in cur.fetchall():
                # CVE already in DB: having an "id" routes it to to_update instead of to_import.
                cve_data[row[1]]["id"] = row[0]

        to_import = [(name, values["description"], values["severity_id"],
                      values["published_date"], values["modified_date"],
                      values["cvss3_score"], values["iava"],
                      values["redhat_url"], values["secondary_url"])
                     for name, values in cve_data.items()
                     if "id" not in values]
        self.logger.log("CVEs to import: %d" % len(to_import))
        to_update = [
            (values["id"], name, values["description"], values["severity_id"],
             values["published_date"], values["modified_date"],
             values["cvss3_score"], values["iava"], values["redhat_url"],
             values["secondary_url"]) for name, values in cve_data.items()
            if "id" in values
        ]

        self.logger.log("CVEs to update: %d" % len(to_update))

        if to_import:
            execute_values(
                cur,
                """insert into cve (name, description, severity_id, published_date, modified_date,
                              cvss3_score, iava, redhat_url, secondary_url) values %s returning id, name""",
                to_import,
                page_size=len(to_import))
            for row in cur.fetchall():
                cve_data[row[1]]["id"] = row[0]

        if to_update:
            execute_values(
                cur,
                """update cve set name = v.name,
                                             description = v.description,
                                             severity_id = v.severity_id,
                                             published_date = v.published_date,
                                             modified_date = v.modified_date,
                                             redhat_url = v.redhat_url,
                                             secondary_url = v.secondary_url,
                                             cvss3_score = v.cvss3_score,
                                             iava = v.iava
                              from (values %s)
                              as v(id, name, description, severity_id, published_date, modified_date, cvss3_score,
                              iava, redhat_url, secondary_url)
                              where cve.id = v.id """,
                to_update,
                page_size=len(to_update),
                # Explicit casts so NULL severity_id / cvss3_score values get a concrete type.
                template=
                b"(%s, %s, %s, %s::int, %s, %s, %s::numeric, %s, %s, %s)")
        self._populate_cwes(cur, cve_data)
        cur.close()
        self.conn.commit()
        return cve_data

    def store(self, repo):
        """
        Store / update cve information in database.
        """
        self.logger.log("Syncing %d CVEs." % repo.get_count())
        self._populate_cves(repo)
        self.logger.log("Syncing CVEs finished.")
示例#10
0
class RepositoryStore:
    """
    Import a repository (its packages and updates) into the database.

    Every helper opens its own cursor and commits its work before returning.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.repositories = []
        self.conn = psycopg2.connect(database=DEFAULT_DB_NAME,
                                     user=DEFAULT_DB_USER,
                                     password=DEFAULT_DB_PASSWORD,
                                     host=DEFAULT_DB_HOST,
                                     port=DEFAULT_DB_PORT)
        # checksum type name -> checksum_type.id, read once when the store is created
        self.checksum_types = self._lookup_checksum_types()

    def _lookup_checksum_types(self):
        """Return dict checksum type name -> checksum_type.id for all rows of checksum_type."""
        checksums = {}
        cur = self.conn.cursor()
        cur.execute("select id, name from checksum_type")
        for row in cur.fetchall():
            checksums[row[1]] = row[0]
        cur.close()
        return checksums

    def _import_repository(self, repo_url):
        """
        Return the id of the repo row named ``dummy_name(repo_url)``, creating
        the row (with eol = false) when it does not exist yet.
        """
        cur = self.conn.cursor()
        cur.execute("select id from repo where name = %s",
                    (dummy_name(repo_url), ))
        repo_id = cur.fetchone()
        if not repo_id:
            # FIXME: add product logic
            cur.execute(
                "insert into repo (name, eol) values (%s, false) returning id",
                (dummy_name(repo_url), ))
            repo_id = cur.fetchone()
        cur.close()
        self.conn.commit()
        return repo_id[0]

    def _populate_evrs(self, packages):
        """
        Make sure every (epoch, ver, rel) found in ``packages`` exists in table evr.

        Returns dict (epoch, version, release) -> evr.id covering both EVRs that
        were already stored and the newly inserted ones.
        """
        cur = self.conn.cursor()
        evr_map = {}
        unique_evrs = {(pkg["epoch"], pkg["ver"], pkg["rel"]) for pkg in packages}
        self.logger.log("Unique EVRs in repository: %d" % len(unique_evrs))
        # execute_values() cannot build a valid statement from an empty argument
        # list, so the lookup is skipped when there are no packages at all
        if unique_evrs:
            execute_values(cur,
                           """select id, epoch, version, release from evr
                       inner join (values %s) t(epoch, version, release)
                       using (epoch, version, release)""",
                           list(unique_evrs),
                           page_size=len(unique_evrs))
            for row in cur.fetchall():
                evr_map[(row[1], row[2], row[3])] = row[0]
                # Already stored -> must not be inserted again
                unique_evrs.discard((row[1], row[2], row[3]))
        self.logger.log("EVRs already in DB: %d" % len(evr_map))
        self.logger.log("EVRs to import: %d" % len(unique_evrs))
        if unique_evrs:
            # FIXME: insert also evr_t column
            execute_values(
                cur,
                """insert into evr (epoch, version, release) values %s
                           returning id, epoch, version, release""",
                list(unique_evrs),
                page_size=len(unique_evrs))
            for row in cur.fetchall():
                evr_map[(row[1], row[2], row[3])] = row[0]
        cur.close()
        self.conn.commit()
        return evr_map

    def _populate_packages(self, packages):
        """
        Make sure every package (identified by checksum type + checksum) exists in table package.

        Returns dict (checksum_type_id, checksum) -> package.id.
        Raises KeyError for a checksum type not present in ``self.checksum_types``.
        """
        evr_map = self._populate_evrs(packages)
        cur = self.conn.cursor()
        pkg_map = {}
        checksums = {(self.checksum_types[pkg["checksum_type"]], pkg["checksum"])
                     for pkg in packages}
        # execute_values() cannot build a valid statement from an empty argument list
        if checksums:
            execute_values(cur,
                           """select id, checksum_type_id, checksum from package
                          inner join (values %s) t(checksum_type_id, checksum)
                          using (checksum_type_id, checksum)
                       """,
                           list(checksums),
                           page_size=len(checksums))
            for row in cur.fetchall():
                pkg_map[(row[1], row[2])] = row[0]
                # Already stored -> must not be inserted again
                checksums.discard((row[1], row[2]))
        self.logger.log("Packages already in DB: %d" % len(pkg_map))
        self.logger.log("Packages to import: %d" % len(checksums))
        if checksums:
            import_data = []
            for pkg in packages:
                key = (self.checksum_types[pkg["checksum_type"]], pkg["checksum"])
                if key in checksums:
                    import_data.append(
                        (pkg["name"], evr_map[(pkg["epoch"], pkg["ver"],
                                               pkg["rel"])],
                         key[0],
                         key[1]))
                    # A package listed several times in the metadata is inserted only once
                    checksums.remove(key)
            execute_values(
                cur,
                """insert into package (name, evr_id, checksum_type_id, checksum) values %s
                              returning id, checksum_type_id, checksum""",
                import_data,
                page_size=len(import_data))
            for row in cur.fetchall():
                pkg_map[(row[1], row[2])] = row[0]
        cur.close()
        self.conn.commit()
        return pkg_map

    def _associate_packages(self, pkg_map, repo_id):
        """
        Make the pkg_repo rows of ``repo_id`` match the package ids in ``pkg_map``:
        missing links are inserted, links to packages not in ``pkg_map`` are deleted.
        """
        cur = self.conn.cursor()
        associated_with_repo = set()
        cur.execute("select pkg_id from pkg_repo where repo_id = %s",
                    (repo_id, ))
        for row in cur.fetchall():
            associated_with_repo.add(row[0])
        self.logger.log("Packages associated to repository: %d" %
                        len(associated_with_repo))
        to_associate = []
        for pkg_id in pkg_map.values():
            if pkg_id in associated_with_repo:
                # Still wanted -> keep; whatever remains in the set gets removed below
                associated_with_repo.remove(pkg_id)
            else:
                to_associate.append(pkg_id)
        self.logger.log("New packages to associate with repository: %d" %
                        len(to_associate))
        self.logger.log("Packages to disassociate with repository: %d" %
                        len(associated_with_repo))
        if to_associate:
            execute_values(cur,
                           "insert into pkg_repo (repo_id, pkg_id) values %s",
                           [(repo_id, pkg_id) for pkg_id in to_associate],
                           page_size=len(to_associate))
        # Are there packages to disassociate?
        if associated_with_repo:
            cur.execute(
                "delete from pkg_repo where repo_id = %s and pkg_id in %s", (
                    repo_id,
                    tuple(associated_with_repo),
                ))
        cur.close()
        self.conn.commit()

    def _import_packages(self, repo_id, packages):
        """Store ``packages`` and link exactly them to repository ``repo_id``."""
        package_map = self._populate_packages(packages)
        self._associate_packages(package_map, repo_id)

    def _lookup_severity(self, severity):
        """
        Return severity.id for ``severity`` (None is stored as the string "None"),
        inserting the severity when it is missing. Commits the connection.
        """
        if severity is None:
            severity = "None"
        cur = self.conn.cursor()
        cur.execute("select id from severity where name = %s", (severity, ))
        severity_id = cur.fetchone()
        if not severity_id:
            # FIXME: optimize
            cur.execute("insert into severity (name) values (%s) returning id",
                        (severity, ))
            severity_id = cur.fetchone()
        cur.close()
        self.conn.commit()
        return severity_id[0]

    def _import_updates(self, repo_id, updates):
        """
        Insert every update whose "id" is not yet present in table errata.

        NOTE(review): ``repo_id`` is not used here, so errata are not linked to the
        repository and existing errata are never updated - confirm this is intended.
        """
        cur = self.conn.cursor()
        # FIXME: optimize
        for update in updates:
            cur.execute("select id from errata where name = %s",
                        (update["id"], ))
            update_id = cur.fetchone()
            if not update_id:
                severity_id = self._lookup_severity(update["severity"])
                cur.execute(
                    "insert into errata (name, synopsis, severity_id) values (%s, %s, %s) returning id",
                    (
                        update["id"],
                        update["title"],
                        severity_id,
                    ))
                update_id = cur.fetchone()
        cur.close()
        self.conn.commit()

    def store(self, repository):
        """Import one repository: its repo row, then all packages, then all updates."""
        self.logger.log("Processing repository: %s" %
                        dummy_name(repository.repo_url))
        repo_id = self._import_repository(repository.repo_url)
        self.logger.log("Importing %d packages." %
                        repository.get_package_count())
        self._import_packages(repo_id, repository.list_packages())
        self.logger.log("Importing packages finished.")
        self.logger.log("Importing %d updates." %
                        repository.get_update_count())
        self._import_updates(repo_id, repository.list_updates())
        self.logger.log("Importing updates finished.")
示例#11
0
class UpdateStore: # pylint: disable=too-few-public-methods
    """
    Class providing interface for storing updates and related info.
    All updates from repository are imported to the DB at once.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.conn = DatabaseHandler.get_connection()

    def _get_nevras_in_repo(self, repo_id):
        """Return dict (name, epoch, version, release, arch) -> package.id for packages linked to repo_id."""
        cur = self.conn.cursor()
        # Select all packages synced from current repository and save them to dict accessible by NEVRA
        nevras_in_repo = {}
        cur.execute("""select p.id, p.name, evr.epoch, evr.version, evr.release, a.name
                               from package p inner join
                                    evr on p.evr_id = evr.id inner join
                                    arch a on p.arch_id = a.id inner join
                                    pkg_repo pr on p.id = pr.pkg_id and pr.repo_id = %s""", (repo_id,))
        for row in cur.fetchall():
            nevras_in_repo[(row[1], row[2], row[3], row[4], row[5])] = row[0]
        cur.close()
        return nevras_in_repo

    def _get_associations_todo(self, repo_id, updates, update_map, update_to_packages):
        """
        Compare the package lists of ``updates`` with the existing update-package links.

        :param update_to_packages: dict errata.id -> set of package ids currently linked
                                   (restricted to packages of ``repo_id``); not modified.
        :returns: (to_associate, to_disassociate), both lists of (package_id, errata_id).
        """
        nevras_in_repo = self._get_nevras_in_repo(repo_id)
        wanted = set()
        to_associate = []
        for update in updates:
            update_id = update_map[update["id"]]
            for pkg in update["pkglist"]:
                nevra = (pkg["name"], pkg["epoch"], pkg["ver"], pkg["rel"], pkg["arch"])
                if nevra not in nevras_in_repo:
                    self.logger.log("NEVRA associated with %s not found in repository: (%s)" %
                                    (update["id"], ",".join(str(item) for item in nevra)))
                    continue
                pair = (nevras_in_repo[nevra], update_id)
                # A package may be listed repeatedly (or an update may appear more than
                # once); handle each pair only once so it is never inserted twice
                if pair in wanted:
                    continue
                wanted.add(pair)
                if pair[0] not in update_to_packages.get(update_id, ()):
                    # Not associated -> associate
                    to_associate.append(pair)

        # Existing links that are no longer present in the metadata are removed
        to_disassociate = [(package_id, update_id)
                           for update_id, package_ids in update_to_packages.items()
                           for package_id in package_ids
                           if (package_id, update_id) not in wanted]

        return to_associate, to_disassociate

    def _populate_severities(self, updates):
        """
        Return dict severity name -> severity.id, inserting missing severities.
        Severity values are converted with str(), so None is stored as "None".
        """
        severities = {}
        cur = self.conn.cursor()
        cur.execute("select id, name from severity")
        for row in cur.fetchall():
            severities[row[1]] = row[0]
        missing_severities = set()
        for update in updates:
            if str(update["severity"]) not in severities:
                missing_severities.add((str(update["severity"]),))
        self.logger.log("Severities missing in DB: %d" % len(missing_severities))
        if missing_severities:
            execute_values(cur, "insert into severity (name) values %s returning id, name",
                           missing_severities, page_size=len(missing_severities))
            for row in cur.fetchall():
                severities[row[1]] = row[0]
        cur.close()
        self.conn.commit()
        return severities

    def _populate_updates(self, updates):
        """
        Make sure every update (by its "id") exists in table errata.

        Returns dict update id -> errata.id.
        """
        severity_map = self._populate_severities(updates)
        cur = self.conn.cursor()
        update_map = {}
        names = {(update["id"],) for update in updates}
        if names:
            execute_values(cur,
                           """select id, name from errata
                              inner join (values %s) t(name)
                              using (name)
                           """, list(names), page_size=len(names))
            for row in cur.fetchall():
                update_map[row[1]] = row[0]
                # Remove to not insert this update
                names.discard((row[1],))
        self.logger.log("Updates already in DB: %d" % len(update_map))
        self.logger.log("Updates to import: %d" % len(names))
        if names:
            import_data = []
            for update in updates:
                if (update["id"],) in names:
                    import_data.append((update["id"], update["title"], severity_map[str(update["severity"])]))
                    # An update listed several times in the metadata is inserted only once
                    names.remove((update["id"],))
            execute_values(cur,
                           """insert into errata (name, synopsis, severity_id) values %s
                              returning id, name""",
                           import_data, page_size=len(import_data))
            for row in cur.fetchall():
                update_map[row[1]] = row[0]
        cur.close()
        self.conn.commit()
        return update_map

    def _associate_packages(self, updates, update_map, repo_id):
        """Synchronize pkg_errata links of these updates for packages of repository repo_id."""
        cur = self.conn.cursor()
        # Select packages already associated with updates, from current repository only
        # Save them to dict: errata_id -> set(package_id)
        update_to_packages = {}
        if update_map:
            cur.execute("""select e.id, pe.pkg_id
                           from errata e inner join
                                pkg_errata pe on e.id = pe.errata_id inner join
                                pkg_repo pr on pe.pkg_id = pr.pkg_id and pr.repo_id = %s
                           where e.id in %s""", (repo_id, tuple(update_map.values()),))
            for row in cur.fetchall():
                update_to_packages.setdefault(row[0], set()).add(row[1])

        to_associate, to_disassociate = self._get_associations_todo(repo_id, updates, update_map, update_to_packages)

        self.logger.log("New update-package associations: %d" % len(to_associate))
        self.logger.log("Update-package disassociations: %d" % len(to_disassociate))

        if to_associate:
            execute_values(cur, "insert into pkg_errata (pkg_id, errata_id) values %s",
                           list(to_associate), page_size=len(to_associate))

        if to_disassociate:
            cur.execute("delete from pkg_errata where (pkg_id, errata_id) in %s", (tuple(to_disassociate),))

        cur.close()
        self.conn.commit()

    def _associate_updates(self, update_map, repo_id):
        """Make errata_repo rows of repo_id match exactly the errata ids in update_map."""
        cur = self.conn.cursor()
        associated_with_repo = set()
        cur.execute("select errata_id from errata_repo where repo_id = %s", (repo_id,))
        for row in cur.fetchall():
            associated_with_repo.add(row[0])
        self.logger.log("Updates associated with repository: %d" % len(associated_with_repo))
        to_associate = []
        for update_id in update_map.values():
            if update_id in associated_with_repo:
                associated_with_repo.remove(update_id)
            else:
                to_associate.append(update_id)
        self.logger.log("New updates to associate with repository: %d" % len(to_associate))
        self.logger.log("Updates to disassociate with repository: %d" % len(associated_with_repo))
        if to_associate:
            execute_values(cur, "insert into errata_repo (repo_id, errata_id) values %s",
                           [(repo_id, update_id) for update_id in to_associate], page_size=len(to_associate))
        # Are there updates to disassociate?
        if associated_with_repo:
            cur.execute("delete from errata_repo where repo_id = %s and errata_id in %s",
                        (repo_id, tuple(associated_with_repo),))
        cur.close()
        self.conn.commit()

    def _populate_cves(self, updates):
        """
        Make sure every CVE referenced by the updates exists in table cve (name only).

        Returns dict CVE name -> cve.id.
        """
        cur = self.conn.cursor()
        cve_map = {}
        names = set()
        for update in updates:
            for reference in update["references"]:
                if reference["type"] == "cve":
                    names.add((reference["id"],))
        if names:
            execute_values(cur,
                           """select id, name from cve
                              inner join (values %s) t(name)
                              using (name)
                           """, list(names), page_size=len(names))
            for row in cur.fetchall():
                cve_map[row[1]] = row[0]
                # Remove to not insert this CVE
                names.discard((row[1],))
        self.logger.log("CVEs already in DB: %d" % len(cve_map))
        self.logger.log("CVEs to import: %d" % len(names))
        if names:
            execute_values(cur,
                           """insert into cve (name) values %s
                              returning id, name""",
                           list(names), page_size=len(names))
            for row in cur.fetchall():
                cve_map[row[1]] = row[0]
        cur.close()
        self.conn.commit()
        return cve_map

    def _associate_cves(self, updates, update_map, cve_map):
        """
        Synchronize table errata_cve with the CVE references of ``updates``.

        Missing links are inserted; existing links of these updates whose CVE is no
        longer referenced are deleted.
        """
        cur = self.conn.cursor()
        update_to_cves = {}
        if update_map:
            cur.execute("select errata_id, cve_id from errata_cve where errata_id in %s",
                        (tuple(update_map.values()),))
            for row in cur.fetchall():
                update_to_cves.setdefault(row[0], set()).add(row[1])

        wanted = set()
        to_associate = []
        for update in updates:
            update_id = update_map[update["id"]]
            for reference in update["references"]:
                if reference["type"] != "cve":
                    continue
                pair = (update_id, cve_map[reference["id"]])
                # Handle each pair once, even when a CVE or an update is listed repeatedly
                if pair in wanted:
                    continue
                wanted.add(pair)
                if pair[1] not in update_to_cves.get(update_id, ()):
                    # Not associated -> associate
                    to_associate.append(pair)

        # Disassociate CVEs no longer referenced by their update
        to_disassociate = [(update_id, cve_id)
                           for update_id, cve_ids in update_to_cves.items()
                           for cve_id in cve_ids
                           if (update_id, cve_id) not in wanted]

        self.logger.log("New update-CVE associations: %d" % len(to_associate))
        self.logger.log("Update-CVE disassociations: %d" % len(to_disassociate))

        if to_associate:
            execute_values(cur, "insert into errata_cve (errata_id, cve_id) values %s",
                           to_associate, page_size=len(to_associate))

        if to_disassociate:
            cur.execute("delete from errata_cve where (errata_id, cve_id) in %s", (tuple(to_disassociate),))

        cur.close()
        self.conn.commit()

    def store(self, repo_id, updates):
        """
        Import all updates from repository into all related DB tables.
        """
        self.logger.log("Syncing %d updates." % len(updates))
        update_map = self._populate_updates(updates)
        self._associate_packages(updates, update_map, repo_id)
        self._associate_updates(update_map, repo_id)
        cve_map = self._populate_cves(updates)
        self._associate_cves(updates, update_map, cve_map)
        self.logger.log("Syncing updates finished.")
示例#12
0
class RepositoryStore:
    """
    List repositories already present in the DB and import repositories one at
    a time, delegating packages to PackageStore and updates to UpdateStore.
    """
    def __init__(self):
        self.content_set_to_db_id = {}
        self.logger = SimpleLogger()
        self.conn = DatabaseHandler.get_connection()
        self.package_store = PackageStore()
        self.update_store = UpdateStore()

    def set_content_set_db_mapping(self, content_set_to_db_id):
        """Remember the content set -> DB id mapping built by product_store."""
        self.content_set_to_db_id = content_set_to_db_id

    def _get_content_set_id(self, repo):
        """Return the DB id of ``repo.content_set``, or None when the mapping lacks it."""
        mapping = self.content_set_to_db_id
        return mapping[repo.content_set] if repo.content_set in mapping else None

    def list_repositories(self):
        """
        Return all stored repositories as dict repo label -> attribute dict
        (id, url, revision, content set and certificate data).
        """
        cur = self.conn.cursor()
        cur.execute(
            """select r.id, r.label, r.url, r.revision, cs.id, cs.label, c.name, c.ca_cert, c.cert, c.key
                       from repo r
                       left join certificate c on r.certificate_id = c.id
                       left join content_set cs on r.content_set_id = cs.id""")
        repos = {}
        for (repo_id, label, url, revision, cs_id, cs_label,
             cert_name, ca_cert, cert, key) in cur.fetchall():
            repos[label] = {
                "id": repo_id,
                "url": url,
                "revision": revision,
                "content_set_id": cs_id,
                "content_set": cs_label,
                "cert_name": cert_name,
                "ca_cert": ca_cert,
                "cert": cert,
                "key": key
            }
        cur.close()
        return repos

    def _import_certificate(self, cert_name, ca_cert, cert, key):
        """Insert or refresh the certificate called ``cert_name`` and return its id."""
        cur = self.conn.cursor()
        cur.execute("select id from certificate where name = %s",
                    (cert_name, ))
        cert_id = cur.fetchone()
        if cert_id:
            # Known certificate: overwrite its stored content
            cur.execute(
                "update certificate set ca_cert = %s, cert = %s, key = %s where name = %s",
                (ca_cert, cert, key, cert_name))
        else:
            cur.execute(
                """insert into certificate (name, ca_cert, cert, key)
                        values (%s, %s, %s, %s) returning id""",
                (cert_name, ca_cert, cert, key))
            cert_id = cur.fetchone()
        cur.close()
        self.conn.commit()
        return cert_id[0]

    def _import_repository(self, repo):
        """
        Create or refresh the repo row labelled ``repo.repo_label`` and return its id.

        A certificate is stored only when ``repo.ca_cert`` is set; the revision,
        certificate and content set references are written on every call.
        """
        cert_id = (self._import_certificate(repo.cert_name, repo.ca_cert,
                                            repo.cert, repo.key)
                   if repo.ca_cert else None)
        cur = self.conn.cursor()
        cur.execute("select id from repo where label = %s",
                    (repo.repo_label, ))
        repo_id = cur.fetchone()
        content_set_id = self._get_content_set_id(repo)
        if repo_id:
            # Existing repository: refresh timestamp and references
            cur.execute(
                """update repo set revision = to_timestamp(%s), certificate_id = %s, content_set_id = %s
                        where id = %s""",
                (repo.repomd.get_revision(), cert_id, content_set_id, repo_id[0]))
        else:
            cur.execute(
                """insert into repo (label, url, revision, eol, certificate_id, content_set_id)
                        values (%s, %s, to_timestamp(%s), false, %s, %s) returning id""",
                (repo.repo_label, repo.repo_url, repo.repomd.get_revision(),
                 cert_id, content_set_id))
            repo_id = cur.fetchone()
        cur.close()
        self.conn.commit()
        return repo_id[0]

    def store(self, repository):
        """
        Persist one repository: its repo row first, then its packages, then its updates.
        The package and update stores decide themselves what is already synced.
        """
        self.logger.log("Syncing repository: %s" % repository.repo_label)
        repo_id = self._import_repository(repository)
        packages = repository.list_packages()
        self.package_store.store(repo_id, packages)
        updates = repository.list_updates()
        self.update_store.store(repo_id, updates)
示例#13
0
class PackageStore:  # pylint: disable=too-few-public-methods
    """
    Persist all packages of one repository in a single pass.

    Architectures, checksum types and EVRs are created on demand, then missing
    packages are inserted and the repository-package links are synchronized.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.conn = DatabaseHandler.get_connection()

    def _populate_archs(self, packages):
        """Return dict arch name -> arch.id, inserting architectures the DB does not know yet."""
        cur = self.conn.cursor()
        cur.execute("select id, name from arch")
        arch_ids = {name: arch_id for arch_id, name in cur.fetchall()}
        missing = {(pkg["arch"], ) for pkg in packages if pkg["arch"] not in arch_ids}
        self.logger.log("Architectures missing in DB: %d" % len(missing))
        if missing:
            execute_values(
                cur,
                "insert into arch (name) values %s returning id, name",
                missing,
                page_size=len(missing))
            arch_ids.update((name, arch_id) for arch_id, name in cur.fetchall())
        cur.close()
        self.conn.commit()
        return arch_ids

    def _populate_checksum_types(self, packages):
        """
        Return dict checksum type name -> checksum_type.id, inserting unknown types.

        Side effect: aliased checksum type names in ``packages`` are rewritten in
        place to their canonical name from CHECKSUM_TYPE_ALIASES.
        """
        cur = self.conn.cursor()
        cur.execute("select id, name from checksum_type")
        type_ids = {name: type_id for type_id, name in cur.fetchall()}
        missing = set()
        for pkg in packages:
            type_name = pkg["checksum_type"]
            # One checksum type can be spelled differently by different repositories
            if type_name in CHECKSUM_TYPE_ALIASES:
                type_name = CHECKSUM_TYPE_ALIASES[type_name]
                pkg["checksum_type"] = type_name
            if type_name not in type_ids:
                missing.add((type_name, ))
        self.logger.log("Checksum types missing in DB: %d" %
                        len(missing))
        if missing:
            execute_values(
                cur,
                "insert into checksum_type (name) values %s returning id, name",
                missing,
                page_size=len(missing))
            type_ids.update((name, type_id) for type_id, name in cur.fetchall())
        cur.close()
        self.conn.commit()
        return type_ids

    def _populate_evrs(self, packages):
        """Return dict (epoch, version, release) -> evr.id, inserting EVRs that are missing."""
        cur = self.conn.cursor()
        evr_ids = {}
        pending = {(pkg["epoch"], pkg["ver"], pkg["rel"]) for pkg in packages}
        self.logger.log("Unique EVRs in repository: %d" % len(pending))
        if pending:
            execute_values(cur,
                           """select id, epoch, version, release from evr
                           inner join (values %s) t(epoch, version, release)
                           using (epoch, version, release)""",
                           list(pending),
                           page_size=len(pending))
            for evr_id, epoch, version, release in cur.fetchall():
                evr_ids[(epoch, version, release)] = evr_id
                # Already stored, so it must not be inserted again
                pending.remove((epoch, version, release))
        self.logger.log("EVRs already in DB: %d" % len(evr_ids))

        # Every EVR is passed twice: the first copy fills the plain columns, the
        # second feeds the row value the template builds for the evr column
        to_import = [evr + evr for evr in pending]
        self.logger.log("EVRs to import: %d" % len(to_import))
        if to_import:
            execute_values(
                cur,
                """insert into evr (epoch, version, release, evr) values %s
                           returning id, epoch, version, release""",
                to_import,
                template=b"(%s, %s, %s, (%s, rpmver_array(%s), rpmver_array(%s)))",
                page_size=len(to_import))
            for evr_id, epoch, version, release in cur.fetchall():
                evr_ids[(epoch, version, release)] = evr_id
        cur.close()
        self.conn.commit()
        return evr_ids

    def _populate_packages(self, packages):
        """
        Return dict (checksum_type_id, checksum) -> package.id, inserting packages
        (with their arch, checksum type and EVR references) that are missing.
        """
        archs = self._populate_archs(packages)
        checksum_types = self._populate_checksum_types(packages)
        evr_map = self._populate_evrs(packages)
        cur = self.conn.cursor()
        pkg_ids = {}
        pending = {(checksum_types[pkg["checksum_type"]], pkg["checksum"]) for pkg in packages}
        if pending:
            execute_values(
                cur,
                """select id, checksum_type_id, checksum from package
                              inner join (values %s) t(checksum_type_id, checksum)
                              using (checksum_type_id, checksum)
                           """,
                list(pending),
                page_size=len(pending))
            for pkg_id, type_id, checksum in cur.fetchall():
                pkg_ids[(type_id, checksum)] = pkg_id
                # Already stored, so it must not be inserted again
                pending.remove((type_id, checksum))
        self.logger.log("Packages already in DB: %d" % len(pkg_ids))
        self.logger.log("Packages to import: %d" % len(pending))
        if pending:
            import_data = []
            for pkg in packages:
                key = (checksum_types[pkg["checksum_type"]], pkg["checksum"])
                if key not in pending:
                    continue
                import_data.append(
                    (pkg["name"], evr_map[(pkg["epoch"], pkg["ver"], pkg["rel"])],
                     archs[pkg["arch"]]) + key)
                # A package listed several times in the metadata is inserted only once
                pending.remove(key)
            execute_values(
                cur,
                """insert into package (name, evr_id, arch_id, checksum_type_id, checksum) values %s
                              returning id, checksum_type_id, checksum""",
                import_data,
                page_size=len(import_data))
            for pkg_id, type_id, checksum in cur.fetchall():
                pkg_ids[(type_id, checksum)] = pkg_id
        cur.close()
        self.conn.commit()
        return pkg_ids

    def _associate_packages(self, pkg_map, repo_id):
        """Make the pkg_repo rows of ``repo_id`` match exactly the package ids in ``pkg_map``."""
        cur = self.conn.cursor()
        cur.execute("select pkg_id from pkg_repo where repo_id = %s",
                    (repo_id, ))
        leftover = {row[0] for row in cur.fetchall()}
        self.logger.log("Packages associated with repository: %d" %
                        len(leftover))
        to_associate = []
        for pkg_id in pkg_map.values():
            if pkg_id not in leftover:
                to_associate.append(pkg_id)
            else:
                # Link already exists and is still wanted
                leftover.remove(pkg_id)
        self.logger.log("New packages to associate with repository: %d" %
                        len(to_associate))
        self.logger.log("Packages to disassociate with repository: %d" %
                        len(leftover))
        if to_associate:
            execute_values(cur,
                           "insert into pkg_repo (repo_id, pkg_id) values %s",
                           [(repo_id, pkg_id) for pkg_id in to_associate],
                           page_size=len(to_associate))
        # Whatever is left is no longer part of the repository metadata
        if leftover:
            cur.execute(
                "delete from pkg_repo where repo_id = %s and pkg_id in %s",
                (repo_id, tuple(leftover)))
        cur.close()
        self.conn.commit()

    def store(self, repo_id, packages):
        """
        Write every package of a repository to the DB and link them to ``repo_id``.
        """
        log = self.logger.log
        log("Syncing %d packages." % len(packages))
        self._associate_packages(self._populate_packages(packages), repo_id)
        log("Syncing packages finished.")
示例#14
0
class CveRepoController:
    """
    Controls import/sync of CVE lists into the DB.
    """
    def __init__(self):
        # CVE lists queued for syncing, and last-modified values per list label
        # as loaded from the DB by add_repos().
        self.repos = set()
        self.db_lastmodified = dict()
        # Helpers used for logging, downloading, unpacking and DB storage.
        self.logger = SimpleLogger()
        self.downloader = FileDownloader()
        self.unpacker = FileUnpacker()
        self.cverepo_store = CveRepoStore()

    def _download_meta(self):
        """
        Download the meta file of every queued CVE list.

        Each repo first gets a fresh temporary directory (prefix ``cverepo-``)
        stored in ``repo.tmp_directory``; then its meta download is queued.

        :returns: dict mapping target path -> status code for every download
                  whose status code is not in VALID_HTTP_CODES
        """
        queued_items = []
        for cve_repo in self.repos:
            cve_repo.tmp_directory = tempfile.mkdtemp(prefix="cverepo-")
            meta_item = DownloadItem(source_url=cve_repo.meta_url(),
                                     target_path=cve_repo.meta_tmp())
            # Kept so the status code can be inspected after the run
            queued_items.append(meta_item)
            self.downloader.add(meta_item)
        self.downloader.run()

        failed = {}
        for meta_item in queued_items:
            if meta_item.status_code not in VALID_HTTP_CODES:
                failed[meta_item.target_path] = meta_item.status_code
        return failed

    def _read_meta(self, failed):
        """
        Read downloaded meta files and mark CVE lists that need syncing.

        For each queued repo whose meta file downloaded successfully,
        ``repo.meta`` is set when the list has no last-modified value in the DB,
        when either last-modified value is unknown, or when the meta file is
        newer than the DB value. Unchanged lists and failed downloads are only
        logged and ``repo.meta`` is left as it was.

        :param failed: dict mapping meta target path -> status code of failed
                       downloads, as returned by ``_download_meta``
        """
        for repo in self.repos:
            meta_path = repo.meta_tmp()
            if meta_path in failed:
                # %s instead of %d: identical output for int codes, but does not
                # raise TypeError if the recorded code is not an int (assumption:
                # it may be None when no HTTP response was received).
                self.logger.log("Download failed: %s (HTTP CODE %s)" % (repo.meta_url(),
                                                                        failed[meta_path]))
                continue

            meta = CveMeta(meta_path)
            # already synced before?
            db_lastmodified = _dt_strptime(self.db_lastmodified.get(repo.label))
            meta_lastmodified = _dt_strptime(meta.get_lastmodified())
            # synced for the first time or has newer revision
            if (db_lastmodified is None
                    or meta_lastmodified is None
                    or meta_lastmodified > db_lastmodified):
                repo.meta = meta
            else:
                self.logger.log("Cve list '%s' has not been updated (since %s)." %
                                (repo.label, str(db_lastmodified)))

    def _download_json(self, batch):
        """Queue the gzipped JSON file of every repo in ``batch`` and download them."""
        for cve_repo in batch:
            json_item = DownloadItem(source_url=cve_repo.json_url(),
                                     target_path=cve_repo.json_tmpgz())
            self.downloader.add(json_item)
        self.downloader.run()

    def _unpack_json(self, batch):
        """Queue the downloaded archive of every repo in ``batch`` and unpack them."""
        for cve_repo in batch:
            archive_path = cve_repo.json_tmpgz()
            self.unpacker.add(archive_path)
        self.unpacker.run()

    def clean_repo(self, batch):
        """Clean downloaded files for given batch."""
        for repo in batch:
            if repo.tmp_directory:
                shutil.rmtree(repo.tmp_directory)
                repo.tmp_directory = None
            self.repos.remove(repo)

    def add_repos(self):
        """
        Queue the CVE lists to download.

        Per-year lists from YEAR_SINCE up to the current year are queued only
        when their label has no entry in the DB yet; the 'recent' and
        'modified' lists are always queued.
        """
        # Fetch current list of repositories from DB
        self.db_lastmodified = self.cverepo_store.list_lastmodified()

        # Per-year files are only needed for the initial load
        current_year = int(time.strftime("%Y"))
        for year in range(YEAR_SINCE, current_year + 1):
            year_label = str(year)
            if year_label not in self.db_lastmodified:
                self.repos.add(CveRepo(year_label))

        # Incremental changes are imported on every run
        for incremental_label in ('recent', 'modified'):
            self.repos.add(CveRepo(incremental_label))

    def store(self):
        """
        Sync all queued CVE lists. Runs in batches due to disk space and memory usage.

        Meta files of all queued lists are downloaded first. Lists whose meta
        download failed or that have not changed since the last sync are
        skipped. The remaining lists are downloaded, unpacked and stored batch by
        batch. Temporary download directories are removed even when a step raises.
        """
        self.logger.log("Checking %d CVE lists." % len(self.repos))
        try:
            # Download all repomd files first
            failed = self._download_meta()
            self.logger.log("%d meta files failed to download." % len(failed))
            self._read_meta(failed)

            # filter out failed / unchanged lists
            batches = BatchList()
            to_skip = []
            for repo in self.repos:
                if repo.meta:
                    batches.add_item(repo)
                else:
                    to_skip.append(repo)
            self.clean_repo(to_skip)
            self.logger.log("%d CVE lists skipped." % len(to_skip))
            self.logger.log("Syncing %d CVE lists." % sum(len(batch) for batch in batches))

            # Download and process repositories in batches (unpacked metadata files can consume lot of disk space)
            for batch in batches:
                self._download_json(batch)
                self._unpack_json(batch)
                for repo in sorted(batch, key=lambda item: item.label):
                    repo.load_json()
                    self.cverepo_store.store(repo)
                    repo.unload_json()
                self.clean_repo(batch)
        finally:
            # _download_meta creates a temp directory for every queued repo. After
            # a normal run every repo should already be cleaned and removed from
            # self.repos; after an error this deletes the directories of the repos
            # that were not cleaned yet instead of leaking them.
            if self.repos:
                self.clean_repo(list(self.repos))