Example #1
    def __init__(self):
        self.logger = SimpleLogger()
        self.downloader = FileDownloader()
        self.unpacker = FileUnpacker()
        self.repo_store = RepositoryStore()
        self.repositories = set()
        self.db_repositories = {}
Example #2
    def __init__(self):
        self.logger = get_logger(__name__)
        self.downloader = FileDownloader()
        self.unpacker = FileUnpacker()
        self.repo_store = RepositoryStore()
        self.repositories = set()
        self.certs_tmp_directory = None
        self.certs_files = {}
Example #3
    def repo_setup(self):
        """Setup repo_store object."""
        product_store = ProductStore()
        product_store.store(PRODUCTS)

        self.repo_store = RepositoryStore()
Example #4
class TestRepositoryStore:
    """TestRepositoryStore class. Test repository store"""

    @pytest.fixture
    def repo_setup(self):
        """Setup repo_store object."""
        product_store = ProductStore()
        product_store.store(PRODUCTS)

        self.repo_store = RepositoryStore()

    @pytest.mark.first
    @pytest.mark.parametrize("repository", REPOSITORIES, ids=[r[0] for r in REPOSITORIES])
    def test_repo_store(self, db_conn, repo_setup, repository):
        """Test storing repository data in DB."""

        # store repository
        self.repo_store.store(repository[1])

        cur = db_conn.cursor()
        cur.execute("select * from repo where url = '{}'".format(repository[1].repo_url))
        repo = cur.fetchone()
        cur.execute("select * from content_set where id = {}".format(repo[REPO_CS_ID]))
        content_set = cur.fetchone()
        cur.execute("select * from product where id = {}".format(content_set[CS_PRODUCT_ID]))
        product = cur.fetchone()
        cur.execute("select * from arch where id = {}".format(repo[REPO_BASEARCH_ID]))
        arch = cur.fetchone()

        assert repo[REPO_URL] == repository[1].repo_url
        assert repo[REPO_RELEASEVER] == repository[1].releasever

        assert content_set[CS_LABEL] == "cs_label"
        assert content_set[CS_NAME] == "cs_name"

        assert product[PRODUCT_NAME] == "product"
        assert product[PRODUCT_RH_ID] == 9

        assert arch[ARCH_NAME] == repository[1].basearch

    @pytest.mark.parametrize("repository", REPOSITORIES, ids=[r[0] for r in REPOSITORIES])
    def test_repo_pkgs(self, db_conn, repository):
        """Test that packages from repo are present in DB."""
        cur = db_conn.cursor()
        cur.execute("select id from repo where url = '{}'".format(repository[1].repo_url))
        repo_id = cur.fetchone()[0]
        cur.execute("select count(*) from pkg_repo where repo_id = {}".format(repo_id))
        pkg_num = cur.fetchone()[0]

        assert pkg_num == 12  # 12 packages expected from primary.xml/primary.db

    @pytest.mark.parametrize("repository", REPOSITORIES, ids=[r[0] for r in REPOSITORIES])
    def test_repo_errata(self, db_conn, repository):
        """Test that errata from repo are present in DB."""
        cur = db_conn.cursor()
        cur.execute("select id from repo where url = '{}'".format(repository[1].repo_url))
        repo_id = cur.fetchone()[0]
        cur.execute("select count(*) from errata_repo where repo_id = {}".format(repo_id))
        errata_num = cur.fetchone()[0]

        # only repositories with updateinfo have errata
        if repository[1].updateinfo:
            assert errata_num == 9  # 9 errata expected from updateinfo.xml
        else:
            assert errata_num == 0

    @pytest.mark.parametrize("repository", REPOSITORIES, ids=[r[0] for r in REPOSITORIES])
    def test_stored_packages(self, db_conn, repository):
        """Test all packages count in package table."""
        cur = db_conn.cursor()
        cur.execute("select evr.epoch, evr.version, evr.release, pn.name, arch.name "
                    "from package "
                    "join package_name pn on package.name_id = pn.id "
                    "join evr on package.evr_id = evr.id "
                    "join arch on package.arch_id = arch.id "
                    "order by evr.evr, pn.name, arch.name")
        rows = cur.fetchall()
        assert len(rows) == 18  # 18 packages expected from primary.xml/primary.db
        # check correct packages order (mainly sorting according to evr)
        assert rows[0] == ('0', '0.0.20', '5.fc27', '3Depict', 'src')
        assert rows[1] == ('0', '0.0.20', '5.fc27', '3Depict', 'x86_64')
        assert rows[2] == ('0', '0.57', '1.fc27', 'BackupPC-XS', 'src')
        assert rows[3] == ('0', '0.57', '1.fc27', 'BackupPC-XS', 'x86_64')
        assert rows[4] == ('0', '1.3.7.8', '1.fc27', '389-ds-base', 'src')
        assert rows[5] == ('0', '1.3.7.8', '1.fc27', '389-ds-base', 'x86_64')
        assert rows[6] == ('0', '1.3.7.8', '1.fc27', '389-ds-base-devel', 'i686')
        assert rows[7] == ('0', '1.3.7.8', '1.fc27', '389-ds-base-devel', 'x86_64')
        assert rows[8] == ('0', '1.3.7.8', '1.fc27', '389-ds-base-libs', 'i686')
        assert rows[9] == ('0', '1.3.7.8', '1.fc27', '389-ds-base-libs', 'x86_64')
        assert rows[10] == ('0', '1.3.7.8', '1.fc27', '389-ds-base-snmp', 'x86_64')
        assert rows[11] == ('0', '1.3.10', '7.fc27', 'CGSI-gSOAP', 'i686')
        assert rows[12] == ('0', '1.3.10', '7.fc27', 'CGSI-gSOAP', 'src')
        assert rows[13] == ('0', '1.3.10', '7.fc27', 'CGSI-gSOAP', 'x86_64')
        assert rows[14] == ('0', '2.5.2', '9.fc27', 'Agda', 'src')
        assert rows[15] == ('0', '2.5.2', '9.fc27', 'Agda', 'x86_64')
        assert rows[16] == ('0', '4.1.5', '1.fc27', 'BackupPC', 'src')
        assert rows[17] == ('0', '4.1.5', '1.fc27', 'BackupPC', 'x86_64')

    @pytest.mark.parametrize("repository", REPOSITORIES, ids=[r[0] for r in REPOSITORIES])
    def test_rpm_pkgs_count(self, db_conn, repository):
        """Test rpm packages count in package table."""
        cur = db_conn.cursor()
        cur.execute("select count(*) from package where source_package_id is not null")
        pkg_num = cur.fetchone()[0]

        assert pkg_num == 12

    @pytest.mark.parametrize("repository", REPOSITORIES, ids=[r[0] for r in REPOSITORIES])
    def test_srpm_pkgs_count(self, db_conn, repository):
        """Test srpm packages count in package table."""
        cur = db_conn.cursor()
        cur.execute("select count(*) from package where source_package_id is null")
        pkg_num = cur.fetchone()[0]

        assert pkg_num == 6

    def test_pkg_errata_count(self, db_conn):
        """Test that package - errata association are stored."""
        cur = db_conn.cursor()
        cur.execute("select count(*) from pkg_errata")
        cnt = cur.fetchone()[0]

        assert cnt == 9
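
The tests above rely on a db_conn fixture that provides a live database connection. A minimal conftest.py sketch of such a fixture, assuming psycopg2 and placeholder connection parameters (the real project wires this to its own test database):

import psycopg2
import pytest


@pytest.fixture
def db_conn():
    """Yield a connection to the test database and close it afterwards."""
    # Placeholder credentials for illustration only.
    conn = psycopg2.connect(host="localhost", dbname="test_db",
                            user="test", password="test")
    yield conn
    conn.close()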
Example #5
class RepositoryController:
    """
    Class for importing/syncing set of repositories into the DB.
    First, repomd.xml files from all repositories are downloaded and parsed.
    Second, primary and updateinfo repodata from repositories needing update are downloaded, parsed and imported.
    """
    def __init__(self):
        self.logger = get_logger(__name__)
        self.downloader = FileDownloader()
        self.unpacker = FileUnpacker()
        self.repo_store = RepositoryStore()
        self.repositories = set()
        self.certs_tmp_directory = None
        self.certs_files = {}

    def _get_certs_tuple(self, name):
        if name in self.certs_files:
            return self.certs_files[name]["ca_cert"], self.certs_files[name]["cert"], self.certs_files[name]["key"]
        return None, None, None

    def _download_repomds(self):
        download_items = []
        for repository in self.repositories:
            repomd_url = urljoin(repository.repo_url, REPOMD_PATH)
            repository.tmp_directory = tempfile.mkdtemp(prefix="repo-")
            ca_cert, cert, key = self._get_certs_tuple(repository.cert_name)
            item = DownloadItem(
                source_url=repomd_url,
                target_path=os.path.join(repository.tmp_directory, "repomd.xml"),
                ca_cert=ca_cert,
                cert=cert,
                key=key
            )
            # Save for future status code check
            download_items.append(item)
            self.downloader.add(item)
        self.downloader.run()
        # Return failed downloads
        return {item.target_path: item.status_code for item in download_items
                if item.status_code not in VALID_HTTP_CODES}

    def _read_repomds(self):
        """Reads all downloaded repomd files. Checks if their download failed and checks if their metadata are
           newer than metadata currently in DB.
        """
        # Fetch current list of repositories from DB
        db_repositories = self.repo_store.list_repositories()
        for repository in self.repositories:
            repomd_path = os.path.join(repository.tmp_directory, "repomd.xml")
            repomd = RepoMD(repomd_path)
            # Was repository already synced before?
            repository_key = (repository.content_set, repository.basearch, repository.releasever)
            if repository_key in db_repositories:
                db_revision = db_repositories[repository_key]["revision"]
            else:
                db_revision = None
            downloaded_revision = repomd.get_revision()
            # Repository is synced for the first time or has newer revision
            if db_revision is None or downloaded_revision > db_revision:
                repository.repomd = repomd
            else:
                self.logger.info("Downloaded repo %s (%s) is not newer than repo in DB (%s).",
                                 ", ".join(filter(None, repository_key)), str(downloaded_revision), str(db_revision))

    def _repo_download_failed(self, repo, failed_items):
        failed = False
        for md_path in list(repo.md_files.values()) + [REPOMD_PATH]:
            local_path = os.path.join(repo.tmp_directory, os.path.basename(md_path))
            if local_path in failed_items:
                failed = True
                self.logger.warning("Download failed: %s (HTTP CODE %d)", urljoin(repo.repo_url, md_path),
                                    failed_items[local_path])
        return failed

    def _download_metadata(self, batch):
        download_items = []
        for repository in batch:
            # primary_db has higher priority, use primary.xml if not found
            try:
                repository.md_files["primary_db"] = repository.repomd.get_metadata("primary_db")["location"]
            except RepoMDTypeNotFound:
                repository.md_files["primary"] = repository.repomd.get_metadata("primary")["location"]
            # updateinfo.xml may be missing completely
            try:
                repository.md_files["updateinfo"] = repository.repomd.get_metadata("updateinfo")["location"]
            except RepoMDTypeNotFound:
                pass
            try:
                repository.md_files["modules"] = repository.repomd.get_metadata("modules")["location"]
            except RepoMDTypeNotFound:
                pass

            # queue metadata files for download
            for md_location in repository.md_files.values():
                ca_cert, cert, key = self._get_certs_tuple(repository.cert_name)
                item = DownloadItem(
                    source_url=urljoin(repository.repo_url, md_location),
                    target_path=os.path.join(repository.tmp_directory, os.path.basename(md_location)),
                    ca_cert=ca_cert,
                    cert=cert,
                    key=key
                )
                download_items.append(item)
                self.downloader.add(item)
        self.downloader.run()
        # Return failed downloads
        return {item.target_path: item.status_code for item in download_items
                if item.status_code not in VALID_HTTP_CODES}

    def _unpack_metadata(self, batch):
        for repository in batch:
            for md_type in repository.md_files:
                self.unpacker.add(os.path.join(repository.tmp_directory,
                                               os.path.basename(repository.md_files[md_type])))
                # FIXME: should this be done in a different place?
                repository.md_files[md_type] = os.path.join(
                    repository.tmp_directory,
                    os.path.basename(repository.md_files[md_type])).rsplit(".", maxsplit=1)[0]
        self.unpacker.run()

    def clean_repodata(self, batch):
        """Clean downloaded repodata of all repositories in batch."""
        for repository in batch:
            if repository.tmp_directory:
                shutil.rmtree(repository.tmp_directory)
                repository.tmp_directory = None
            self.repositories.remove(repository)

    def _clean_certificate_cache(self):
        if self.certs_tmp_directory:
            shutil.rmtree(self.certs_tmp_directory)
            self.certs_tmp_directory = None
            self.certs_files = {}

    def add_db_repositories(self):
        """Queue all previously imported repositories."""
        repos = self.repo_store.list_repositories()
        for (content_set, basearch, releasever), repo_dict in repos.items():
            # Reference content_set_label -> content set id
            self.repo_store.content_set_to_db_id[content_set] = repo_dict["content_set_id"]
            self.repositories.add(Repository(repo_dict["url"], content_set, basearch, releasever,
                                             cert_name=repo_dict["cert_name"], ca_cert=repo_dict["ca_cert"],
                                             cert=repo_dict["cert"], key=repo_dict["key"]))

    def add_repository(self, repo_url, content_set, basearch, releasever,
                       cert_name=None, ca_cert=None, cert=None, key=None):
        """Queue repository to import/check updates."""
        repo_url = repo_url.strip()
        if not repo_url.endswith("/"):
            repo_url += "/"
        self.repositories.add(Repository(repo_url, content_set, basearch, releasever, cert_name=cert_name,
                                         ca_cert=ca_cert, cert=cert, key=key))

    def _write_certificate_cache(self):
        certs = {}
        for repository in self.repositories:
            if repository.cert_name:
                certs[repository.cert_name] = {"ca_cert": repository.ca_cert, "cert": repository.cert,
                                               "key": repository.key}
        if certs:
            self.certs_tmp_directory = tempfile.mkdtemp(prefix="certs-")
            for cert_name in certs:
                self.certs_files[cert_name] = {}
                for cert_type in ["ca_cert", "cert", "key"]:
                    # Cert is not None
                    if certs[cert_name][cert_type]:
                        cert_path = os.path.join(self.certs_tmp_directory, "%s.%s" % (cert_name, cert_type))
                        with open(cert_path, "w") as cert_file:
                            cert_file.write(certs[cert_name][cert_type])
                        self.certs_files[cert_name][cert_type] = cert_path
                    else:
                        self.certs_files[cert_name][cert_type] = None

    def _find_content_sets_by_regex(self, content_set_regex):
        if not content_set_regex.startswith('^'):
            content_set_regex = '^' + content_set_regex

        if not content_set_regex.endswith('$'):
            content_set_regex = content_set_regex + '$'

        return [content_set_label for content_set_label in self.repo_store.content_set_to_db_id
                if re.match(content_set_regex, content_set_label)]

    def delete_content_set(self, content_set_regex):
        """Deletes content sets described by given regex from DB."""
        for content_set_label in self._find_content_sets_by_regex(content_set_regex):
            self.logger.info("Deleting content set: %s", content_set_label)
            self.repo_store.delete_content_set(content_set_label)
        self.repo_store.cleanup_unused_data()

    def import_repositories(self):
        """Create or update repository records in the DB."""
        self.logger.info("Importing %d repositories.", len(self.repositories))
        for repository in self.repositories:
            self.repo_store.import_repository(repository)

    def store(self):
        """Sync all queued repositories. Process repositories in batches due to disk space and memory usage."""
        self.logger.info("Checking %d repositories.", len(self.repositories))

        self._write_certificate_cache()

        # Download all repomd files first
        failed = self._download_repomds()
        if failed:
            self.logger.warning("%d repomd.xml files failed to download.", len(failed))
            failed_repos = [repo for repo in self.repositories if self._repo_download_failed(repo, failed)]
            self.clean_repodata(failed_repos)

        self._read_repomds()
        # Filter all repositories without repomd attribute set (failed download, downloaded repomd is not newer)
        batches = BatchList()
        to_skip = []
        for repository in self.repositories:
            if repository.repomd:
                batches.add_item(repository)
            else:
                to_skip.append(repository)
        self.clean_repodata(to_skip)
        self.logger.info("%d repositories skipped.", len(to_skip))
        self.logger.info("Syncing %d repositories.", sum(len(l) for l in batches))

        # Download and process repositories in batches (unpacked metadata files can consume lot of disk space)
        for batch in batches:
            failed = self._download_metadata(batch)
            if failed:
                self.logger.warning("%d metadata files failed to download.", len(failed))
                failed_repos = [repo for repo in batch if self._repo_download_failed(repo, failed)]
                self.clean_repodata(failed_repos)
                batch = [repo for repo in batch if repo not in failed_repos]
            self._unpack_metadata(batch)
            for repository in batch:
                repository.load_metadata()
                self.repo_store.store(repository)
                repository.unload_metadata()
            self.clean_repodata(batch)

        self.repo_store.cleanup_unused_data()
        self._clean_certificate_cache()
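
A hypothetical driver for the controller above, showing the intended call order; the repository URL and content-set labels are placeholders invented for illustration:

controller = RepositoryController()
# Re-queue everything previously imported into the DB...
controller.add_db_repositories()
# ...and queue one new repository (hypothetical URL and labels).
controller.add_repository("https://cdn.example.com/content/os/x86_64/",
                          "example-content-set", "x86_64", "8")
controller.import_repositories()  # create/update repo records in the DB
controller.store()                # download, compare revisions, parse, import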
Example #6
class RepositoryController:
    """
    Class for importing/syncing set of repositories into the DB.
    First, repomd.xml files from all repositories are downloaded and parsed.
    Second, primary and updateinfo repodata from repositories needing update are downloaded, parsed and imported.
    """

    def __init__(self):
        self.logger = get_logger(__name__)
        self.downloader = FileDownloader()
        self.unpacker = FileUnpacker()
        self.repo_store = RepositoryStore()
        self.repositories = set()
        self.certs_tmp_directory = None
        self.certs_files = {}

    def _get_certs_tuple(self, name):
        if name in self.certs_files:
            return self.certs_files[name]["ca_cert"], self.certs_files[name]["cert"], self.certs_files[name]["key"]
        return None, None, None

    def _download_repomds(self):
        download_items = []
        certs_tmp_dict = {}
        for repository in self.repositories:
            repomd_url = urljoin(repository.repo_url, REPOMD_PATH)
            repository.tmp_directory = tempfile.mkdtemp(prefix="repo-")
            ca_cert, cert, key = self._get_certs_tuple(repository.cert_name)
            # Check certificate expiration date
            if repository.cert_name:
                certs_tmp_dict[repository.cert_name] = cert

            item = DownloadItem(
                source_url=repomd_url,
                target_path=os.path.join(repository.tmp_directory, "repomd.xml"),
                ca_cert=ca_cert,
                cert=cert,
                key=key
            )
            # Save for future status code check
            download_items.append(item)
            self.downloader.add(item)

        for cert_name, cert in certs_tmp_dict.items():
            self._check_cert_expiration_date(cert_name, cert)
        self.downloader.run()
        # Return failed downloads
        return {item.target_path: item.status_code for item in download_items
                if item.status_code not in VALID_HTTP_CODES}

    def _check_cert_expiration_date(self, cert_name, cert):
        try:
            # Load certificate
            loaded_cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
            # Get expiration date and parse it to datetime object
            valid_to_dt = datetime.strptime(loaded_cert.get_notAfter().decode("utf-8"), "%Y%m%d%H%M%SZ")
            expire_in_days_td = (valid_to_dt - datetime.utcnow()).days
            expire_tuple = (valid_to_dt, expire_in_days_td)
            if 30 >= expire_in_days_td > 0:
                self.logger.warning('Certificate %s will expire in %d days', cert_name, expire_in_days_td)
                msg = prepare_msg_for_slack(cert_name, 'Reposcan CDN certificate will expire soon', expire_tuple)
                send_slack_notification(msg)
            elif expire_in_days_td <= 0:
                self.logger.warning('Certificate %s expired!', cert_name)
                msg = prepare_msg_for_slack(cert_name, 'Reposcan CDN certificate expired', expire_tuple)
                send_slack_notification(msg)
        except crypto.Error:
            self.logger.warning('Certificate not provided or incorrect: %s', cert_name if cert_name else 'None')
            msg = prepare_msg_for_slack(cert_name, 'Reposcan CDN certificate not provided or incorrect')
            send_slack_notification(msg)

    def _read_repomds(self):
        """Reads all downloaded repomd files. Checks if their download failed and checks if their metadata are
           newer than metadata currently in DB.
        """
        # Fetch current list of repositories from DB
        db_repositories = self.repo_store.list_repositories()
        for repository in self.repositories:
            repomd_path = os.path.join(repository.tmp_directory, "repomd.xml")
            repomd = RepoMD(repomd_path)
            # Was repository already synced before?
            repository_key = (repository.content_set, repository.basearch, repository.releasever)
            if repository_key in db_repositories:
                db_revision = db_repositories[repository_key]["revision"]
            else:
                db_revision = None
            downloaded_revision = repomd.get_revision()
            # Repository is synced for the first time or has newer revision
            if db_revision is None or downloaded_revision > db_revision:
                repository.repomd = repomd
            else:
                self.logger.debug("Downloaded repo %s (%s) is not newer than repo in DB (%s).",
                                  ", ".join(filter(None, repository_key)), str(downloaded_revision), str(db_revision))

    def _repo_download_failed(self, repo, failed_items):
        failed = False
        for md_path in list(repo.md_files.values()) + [REPOMD_PATH]:
            local_path = os.path.join(repo.tmp_directory, os.path.basename(md_path))
            if local_path in failed_items:
                failed = True
                # Download errors with no HTTP code are logged in downloader, deduplicate error msgs
                if failed_items[local_path] > 0:
                    self.logger.warning("Download failed: LABEL: %s URL: %s (HTTP CODE %d)",
                                        repo.content_set, urljoin(repo.repo_url, md_path),
                                        failed_items[local_path])
                    FAILED_REPO_WITH_HTTP_CODE.labels(failed_items[local_path]).inc()
        return failed

    def _download_metadata(self, batch):
        download_items = []
        for repository in batch:
            # primary_db has higher priority, use primary.xml if not found
            try:
                repository.md_files["primary_db"] = repository.repomd.get_metadata("primary_db")["location"]
            except RepoMDTypeNotFound:
                repository.md_files["primary"] = repository.repomd.get_metadata("primary")["location"]
            # updateinfo.xml may be missing completely
            try:
                repository.md_files["updateinfo"] = repository.repomd.get_metadata("updateinfo")["location"]
            except RepoMDTypeNotFound:
                pass
            try:
                repository.md_files["modules"] = repository.repomd.get_metadata("modules")["location"]
            except RepoMDTypeNotFound:
                pass

            # queue metadata files for download
            for md_location in repository.md_files.values():
                ca_cert, cert, key = self._get_certs_tuple(repository.cert_name)
                item = DownloadItem(
                    source_url=urljoin(repository.repo_url, md_location),
                    target_path=os.path.join(repository.tmp_directory, os.path.basename(md_location)),
                    ca_cert=ca_cert,
                    cert=cert,
                    key=key
                )
                download_items.append(item)
                self.downloader.add(item)
        self.downloader.run()
        # Return failed downloads
        return {item.target_path: item.status_code for item in download_items
                if item.status_code not in VALID_HTTP_CODES}

    def _unpack_metadata(self, batch):
        for repository in batch:
            for md_type in repository.md_files:
                self.unpacker.add(os.path.join(repository.tmp_directory,
                                               os.path.basename(repository.md_files[md_type])))
                # FIXME: should this be done in a different place?
                repository.md_files[md_type] = os.path.join(
                    repository.tmp_directory,
                    os.path.basename(repository.md_files[md_type])).rsplit(".", maxsplit=1)[0]
        self.unpacker.run()

    def clean_repodata(self, batch):
        """Clean downloaded repodata of all repositories in batch."""
        for repository in batch:
            if repository.tmp_directory:
                shutil.rmtree(repository.tmp_directory)
                repository.tmp_directory = None
            self.repositories.remove(repository)

    def _clean_certificate_cache(self):
        if self.certs_tmp_directory:
            shutil.rmtree(self.certs_tmp_directory)
            self.certs_tmp_directory = None
            self.certs_files = {}

    def add_db_repositories(self):
        """Queue all previously imported repositories."""
        repos = self.repo_store.list_repositories()
        for (content_set, basearch, releasever), repo_dict in repos.items():
            # Reference content_set_label -> content set id
            self.repo_store.content_set_to_db_id[content_set] = repo_dict["content_set_id"]
            self.repositories.add(Repository(repo_dict["url"], content_set, basearch, releasever,
                                             cert_name=repo_dict["cert_name"], ca_cert=repo_dict["ca_cert"],
                                             cert=repo_dict["cert"], key=repo_dict["key"]))

    def add_repository(self, repo_url, content_set, basearch, releasever,
                       cert_name=None, ca_cert=None, cert=None, key=None):
        """Queue repository to import/check updates."""
        repo_url = repo_url.strip()
        if not repo_url.endswith("/"):
            repo_url += "/"
        self.repositories.add(Repository(repo_url, content_set, basearch, releasever, cert_name=cert_name,
                                         ca_cert=ca_cert, cert=cert, key=key))

    def _write_certificate_cache(self):
        certs = {}
        for repository in self.repositories:
            if repository.cert_name:
                certs[repository.cert_name] = {"ca_cert": repository.ca_cert, "cert": repository.cert,
                                               "key": repository.key}
        if certs:
            self.certs_tmp_directory = tempfile.mkdtemp(prefix="certs-")
            for cert_name in certs:
                self.certs_files[cert_name] = {}
                for cert_type in ["ca_cert", "cert", "key"]:
                    # Cert is not None
                    if certs[cert_name][cert_type]:
                        cert_path = os.path.join(self.certs_tmp_directory, "%s.%s" % (cert_name, cert_type))
                        with open(cert_path, "w") as cert_file:
                            cert_file.write(certs[cert_name][cert_type])
                        self.certs_files[cert_name][cert_type] = cert_path
                    else:
                        self.certs_files[cert_name][cert_type] = None

    def _find_content_sets_by_regex(self, content_set_regex):
        if not content_set_regex.startswith('^'):
            content_set_regex = '^' + content_set_regex

        if not content_set_regex.endswith('$'):
            content_set_regex = content_set_regex + '$'

        return [content_set_label for content_set_label in self.repo_store.content_set_to_db_id
                if re.match(content_set_regex, content_set_label)]

    def delete_content_set(self, content_set_regex):
        """Deletes content sets described by given regex from DB."""
        for content_set_label in self._find_content_sets_by_regex(content_set_regex):
            self.logger.info("Deleting content set: %s", content_set_label)
            self.repo_store.delete_content_set(content_set_label)
        self.repo_store.cleanup_unused_data()

    def import_repositories(self):
        """Create or update repository records in the DB."""
        self.logger.info("Importing %d repositories.", len(self.repositories))
        failures = 0
        for repository in self.repositories:
            try:
                self.repo_store.import_repository(repository)
            except Exception:  # pylint: disable=broad-except
                failures += 1
        if failures > 0:
            self.logger.warning("Failed to import %d repositories.", failures)
            FAILED_IMPORT_REPO.inc(failures)

    def store(self):
        """Sync all queued repositories. Process repositories in batches due to disk space and memory usage."""
        self.logger.info("Checking %d repositories.", len(self.repositories))

        self._write_certificate_cache()

        # Download all repomd files first
        failed = self._download_repomds()
        if failed:
            FAILED_REPOMD.inc(len(failed))
            failed_repos = [repo for repo in self.repositories if self._repo_download_failed(repo, failed)]
            self.logger.warning("%d repomd.xml files failed to download.", len(failed))
            self.clean_repodata(failed_repos)

        self._read_repomds()
        # Filter all repositories without repomd attribute set (downloaded repomd is not newer)
        batches = BatchList()
        up_to_date = []

        def md_size(repomd, data_type):
            try:
                mdata = repomd.get_metadata(data_type)
                # open-size is not present for uncompressed files
                return int(mdata.get('size', 0)) + int(mdata.get('open-size', '0'))
            except RepoMDTypeNotFound:
                return 0

        for repository in self.repositories:
            if repository.repomd:

                repo_size = md_size(repository.repomd, 'primary_db')
                # If we use primary_db, we don't even download primary data xml
                if repo_size == 0:
                    repo_size += md_size(repository.repomd, 'primary')

                repo_size += md_size(repository.repomd, 'updateinfo')
                repo_size += md_size(repository.repomd, 'modules')

                batches.add_item(repository, repo_size)
            else:
                up_to_date.append(repository)

        self.clean_repodata(up_to_date)
        self.logger.info("%d repositories are up to date.", len(up_to_date))
        total_repositories = batches.get_total_items()
        completed_repositories = 0
        self.logger.info("%d repositories need to be synced.", total_repositories)

        # Download and process repositories in batches (unpacked metadata files can consume lot of disk space)
        try:
            for batch in batches:
                self.logger.info("Syncing a batch of %d repositories", len(batch))
                try:
                    failed = self._download_metadata(batch)
                    if failed:
                        self.logger.warning("%d metadata files failed to download.", len(failed))
                        failed_repos = [repo for repo in batch if self._repo_download_failed(repo, failed)]
                        self.clean_repodata(failed_repos)
                        batch = [repo for repo in batch if repo not in failed_repos]
                    self._unpack_metadata(batch)
                    for repository in batch:
                        repository.load_metadata()
                        completed_repositories += 1
                        self.logger.info("Syncing repository: %s [%s/%s]", ", ".join(
                            filter(None, (repository.content_set, repository.basearch, repository.releasever))),
                                         completed_repositories, total_repositories)
                        self.repo_store.store(repository)
                        repository.unload_metadata()
                finally:
                    self.clean_repodata(batch)
        finally:
            self.repo_store.cleanup_unused_data()
            self._clean_certificate_cache()
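
The certificate expiry check above can be exercised in isolation. A minimal sketch, assuming pyOpenSSL is installed; the certificate path in the usage comment is a placeholder:

from datetime import datetime

from OpenSSL import crypto


def days_until_expiry(cert_pem):
    """Return the number of days before a PEM certificate expires."""
    loaded_cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
    # get_notAfter() returns bytes such as b"20301231235959Z"
    valid_to_dt = datetime.strptime(loaded_cert.get_notAfter().decode("utf-8"),
                                    "%Y%m%d%H%M%SZ")
    return (valid_to_dt - datetime.utcnow()).days


# with open("client.crt") as cert_file:  # placeholder path
#     print(days_until_expiry(cert_file.read()))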
Example #7
class TestRepositoryStore:
    """TestRepositoryStore class. Test repository store"""
    @pytest.fixture
    def repo_setup(self):
        """Setup repo_store object."""
        product_store = ProductStore()
        product_store.store(PRODUCTS)

        self.repo_store = RepositoryStore()

    @pytest.mark.first
    @pytest.mark.parametrize("repository",
                             REPOSITORIES,
                             ids=[r[0] for r in REPOSITORIES])
    def test_repo_store(self, db_conn, repo_setup, repository):
        """Test storing repository data in DB."""

        # updates with updated = None result in an IntegrityError
        if repository[1].updateinfo:
            for update in repository[1].updateinfo.updates:
                if update["updated"] is None:
                    update["updated"] = datetime.now(utc.UTC)

        # store repository
        self.repo_store.store(repository[1])

        cur = db_conn.cursor()
        cur.execute("select * from repo where url = '{}'".format(
            repository[1].repo_url))
        repo = cur.fetchone()
        cur.execute("select * from content_set where id = {}".format(
            repo[REPO_CS_ID]))
        content_set = cur.fetchone()
        cur.execute("select * from product where id = {}".format(
            content_set[CS_PRODUCT_ID]))
        product = cur.fetchone()
        cur.execute("select * from arch where id = {}".format(
            repo[REPO_BASEARCH_ID]))
        arch = cur.fetchone()

        assert repo[REPO_URL] == repository[1].repo_url
        assert repo[REPO_RELEASEVER] == repository[1].releasever

        assert content_set[CS_LABEL] == "cs_label"
        assert content_set[CS_NAME] == "cs_name"

        assert product[PRODUCT_NAME] == "product"
        assert product[PRODUCT_RH_ID] == 9

        assert arch[ARCH_NAME] == repository[1].basearch

    @pytest.mark.parametrize("repository",
                             REPOSITORIES,
                             ids=[r[0] for r in REPOSITORIES])
    def test_repo_pkgs(self, db_conn, repository):
        """Test that packages from repo are present in DB."""
        cur = db_conn.cursor()
        cur.execute("select id from repo where url = '{}'".format(
            repository[1].repo_url))
        repo_id = cur.fetchone()[0]
        cur.execute(
            "select count(*) from pkg_repo where repo_id = {}".format(repo_id))
        pkg_num = cur.fetchone()[0]

        assert pkg_num > 0

    @pytest.mark.parametrize("repository",
                             REPOSITORIES,
                             ids=[r[0] for r in REPOSITORIES])
    def test_repo_errata(self, db_conn, repository):
        """Test that errata from repo are present in DB."""
        cur = db_conn.cursor()
        cur.execute("select id from repo where url = '{}'".format(
            repository[1].repo_url))
        repo_id = cur.fetchone()[0]
        cur.execute(
            "select count(*) from errata_repo where repo_id = {}".format(
                repo_id))
        errata_num = cur.fetchone()[0]

        # only repositories with updateinfo have errata
        if repository[1].updateinfo:
            assert errata_num > 0
        else:
            assert errata_num == 0
Example #8
class RepositoryController:
    """
    Class for importing/syncing set of repositories into the DB.
    First, repomd.xml files from all repositories are downloaded and parsed.
    Second, primary and updateinfo repodata from repositories needing update are downloaded, parsed and imported.
    """
    def __init__(self):
        self.logger = SimpleLogger()
        self.downloader = FileDownloader()
        self.unpacker = FileUnpacker()
        self.repo_store = RepositoryStore()
        self.repositories = set()
        self.db_repositories = {}

    def _download_repomds(self):
        download_items = []
        for repository in self.repositories:
            repomd_url = urljoin(repository.repo_url, REPOMD_PATH)
            repository.tmp_directory = tempfile.mkdtemp(prefix="repo-")
            item = DownloadItem(
                source_url=repomd_url,
                target_path=os.path.join(repository.tmp_directory, "repomd.xml")
            )
            # Save for future status code check
            download_items.append(item)
            self.downloader.add(item)
        self.downloader.run()
        # Return failed downloads
        return {item.target_path: item.status_code for item in download_items
                if item.status_code not in VALID_HTTP_CODES}

    def _read_repomds(self, failed):
        """Reads all downloaded repomd files. Checks if their download failed and checks if their metadata are
           newer than metadata currently in DB.
        """
        for repository in self.repositories:
            repomd_path = os.path.join(repository.tmp_directory, "repomd.xml")
            if repomd_path not in failed:
                repomd = RepoMD(repomd_path)
                # Was repository already synced before?
                if repository.repo_url in self.db_repositories:
                    db_revision = self.db_repositories[repository.repo_url]["revision"]
                else:
                    db_revision = None
                downloaded_revision = datetime.fromtimestamp(repomd.get_revision(), tz=timezone.utc)
                # Repository is synced for the first time or has newer revision
                if db_revision is None or downloaded_revision > db_revision:
                    repository.repomd = repomd
                else:
                    self.logger.log("Downloaded repo %s (%s) is not newer than repo in DB (%s)." %
                                    (repository.repo_url, str(downloaded_revision), str(db_revision)))
            else:
                self.logger.log("Download failed: %s (HTTP CODE %d)" % (urljoin(repository.repo_url, REPOMD_PATH),
                                                                        failed[repomd_path]))

    def _download_metadata(self, batch):
        for repository in batch:
            # primary_db has higher priority, use primary.xml if not found
            try:
                repository.md_files["primary_db"] = repository.repomd.get_metadata("primary_db")["location"]
            except RepoMDTypeNotFound:
                repository.md_files["primary"] = repository.repomd.get_metadata("primary")["location"]
            # updateinfo.xml may be missing completely
            try:
                repository.md_files["updateinfo"] = repository.repomd.get_metadata("updateinfo")["location"]
            except RepoMDTypeNotFound:
                pass

            # queue metadata files for download
            for md_location in repository.md_files.values():
                self.downloader.add(DownloadItem(
                    source_url=urljoin(repository.repo_url, md_location),
                    target_path=os.path.join(repository.tmp_directory, os.path.basename(md_location))
                ))
        self.downloader.run()

    def _unpack_metadata(self, batch):
        for repository in batch:
            for md_type in repository.md_files:
                self.unpacker.add(os.path.join(repository.tmp_directory,
                                               os.path.basename(repository.md_files[md_type])))
                # FIXME: should this be done in a different place?
                repository.md_files[md_type] = os.path.join(
                    repository.tmp_directory,
                    os.path.basename(repository.md_files[md_type])).rsplit(".", maxsplit=1)[0]
        self.unpacker.run()

    def clean_repodata(self, batch):
        """Clean downloaded repodata of all repositories in batch."""
        for repository in batch:
            if repository.tmp_directory:
                shutil.rmtree(repository.tmp_directory)
                repository.tmp_directory = None
            self.repositories.remove(repository)

    def add_repository(self, repo_url):
        """Queue repository to import/check updates."""
        repo_url = repo_url.strip()
        if not repo_url.endswith("/"):
            repo_url += "/"
        self.repositories.add(Repository(repo_url))

    def store(self):
        """Sync all queued repositories. Process repositories in batches due to disk space and memory usage."""
        self.logger.log("Checking %d repositories." % len(self.repositories))

        # Fetch current list of repositories from DB
        self.db_repositories = self.repo_store.list_repositories()

        # Download all repomd files first
        failed = self._download_repomds()
        self.logger.log("%d repomd.xml files failed to download." % len(failed))
        self._read_repomds(failed)

        # Filter all repositories without repomd attribute set (failed download, downloaded repomd is not newer)
        batches = BatchList()
        to_skip = []
        for repository in self.repositories:
            if repository.repomd:
                batches.add_item(repository)
            else:
                to_skip.append(repository)
        self.clean_repodata(to_skip)
        self.logger.log("%d repositories skipped." % len(to_skip))
        self.logger.log("Syncing %d repositories." % sum(len(l) for l in batches))

        # Download and process repositories in batches (unpacked metadata files can consume lot of disk space)
        for batch in batches:
            self._download_metadata(batch)
            self._unpack_metadata(batch)
            for repository in batch:
                repository.load_metadata()
                self.repo_store.store(repository)
                repository.unload_metadata()
            self.clean_repodata(batch)
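
All three controller versions iterate over a BatchList. A hypothetical minimal implementation, assuming a fixed per-batch item cap (the Example #6 variant additionally passes a size so batches can also be capped by metadata volume):

BATCH_MAX_ITEMS = 50  # assumed cap, for illustration only


class BatchList:
    """Group queued items into batches of at most BATCH_MAX_ITEMS."""
    def __init__(self):
        self.batches = []

    def add_item(self, item):
        if not self.batches or len(self.batches[-1]) >= BATCH_MAX_ITEMS:
            self.batches.append([])
        self.batches[-1].append(item)

    def get_total_items(self):
        return sum(len(batch) for batch in self.batches)

    def __iter__(self):
        return iter(self.batches)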