Example #1
0
 def setUpClass(cls):
     """Reset the test database and open a shared connection for the class.

     Raises:
         AssertionError: if the engine's connection string does not
             contain "fortesting"; checked before anything is dropped.
     """
     # Explicit raise rather than ``assert``: asserts are stripped under
     # ``python -O``, and this check is what keeps drop_db() from running
     # against an engine that was not redirected to the test database.
     if "fortesting" not in str(cms.db.engine):
         raise AssertionError(
             "Monkey patching of DB connection string failed")
     drop_db()
     init_db()
     cls.connection = cms.db.engine.connect()
     cms.db.metadata.create_all(cls.connection)
Example #2
0
 def setUpClass(cls):
     """Recreate the test database schema and connect to it.

     Raises:
         AssertionError: when the engine URL lacks the "fortesting"
             marker; raised before drop_db() is called.
     """
     # Not an ``assert`` statement on purpose: ``python -O`` removes
     # asserts, which would silently disable this guard in front of
     # drop_db().
     if "fortesting" not in str(cms.db.engine):
         raise AssertionError(
             "Monkey patching of DB connection string failed")
     drop_db()
     init_db()
     cls.connection = cms.db.engine.connect()
     cms.db.metadata.create_all(cls.connection)
Example #3
0
 def _prepare_db(self):
     """Drop and recreate the database when ``self.drop`` is set.

     Returns:
         bool: False if drop_db()/init_db() reported failure or raised,
         True otherwise (including when no drop was requested).
     """
     logger.info("Creating database structure.")
     if self.drop:
         try:
             if not (drop_db() and init_db()):
                 # Report the failure instead of returning silently.
                 # No exception is active here, so no traceback is
                 # attached.
                 logger.critical("Unexpected error while dropping and "
                                 "recreating the database.")
                 return False
         except Exception:
             # Log the full traceback, not just repr(error).
             logger.critical("Unable to access DB.", exc_info=True)
             return False
     return True
Example #4
0
 def _prepare_db(self):
     """Prepare the database, dropping and recreating it if requested.

     Returns:
         bool: False if dropping/recreating reported failure or raised,
         True otherwise (including when ``self.drop`` is false).
     """
     logger.info("Creating database structure.")
     if self.drop:
         try:
             if not (drop_db() and init_db()):
                 # We are not inside an except clause here, so there is
                 # no active exception for exc_info to report.
                 logger.critical("Unexpected error while dropping and "
                                 "recreating the database.")
                 return False
         except Exception:
             # Attach the traceback instead of formatting only repr().
             logger.critical("Unable to access DB.", exc_info=True)
             return False
     return True
Example #5
0
 def _prepare_db(self):
     """Drop and recreate the database if ``self.drop`` is true.

     Returns:
         bool: True when nothing had to be done or the reset succeeded;
         False when drop_db()/init_db() reported failure or raised.
     """
     logger.info("Creating database structure.")
     if self.drop:
         try:
             if not (drop_db() and init_db()):
                 # Outside an except clause there is no exception to
                 # attach, so exc_info is deliberately not passed.
                 logger.critical(
                     "Unexpected error while dropping and "
                     "recreating the database.")
                 return False
         except Exception:
             # Full traceback rather than an eagerly formatted repr().
             logger.critical("Unable to access DB.", exc_info=True)
             return False
     return True
Example #6
0
    def do_import(self):
        """Run the actual import code.

        If ``self.import_source`` is an archive supported by ``Archive``,
        it is unpacked and its single top-level directory becomes
        ``self.import_dir``. Then, according to the instance flags, the
        database is dropped and recreated (``drop``), the contest data is
        loaded from ``contest.json`` (``load_model``) and the stored files
        are imported (``load_files``).

        Returns:
            True on success; False on failure, after logging the reason
            at critical level.
        """
        logger.info("Starting import.")

        archive = None
        if Archive.is_supported(self.import_source):
            archive = Archive(self.import_source)
            self.import_dir = archive.unpack()

        # Everything after unpacking runs inside try/finally so that the
        # unpacked archive is cleaned up on every exit path, including the
        # early "return False" ones and unexpected exceptions.
        try:
            if archive is not None:
                file_names = os.listdir(self.import_dir)
                if len(file_names) != 1:
                    logger.critical("Cannot find a root directory in %s.",
                                    self.import_source)
                    return False

                self.import_dir = os.path.join(self.import_dir,
                                               file_names[0])

            if self.drop:
                logger.info("Dropping and recreating the database.")
                try:
                    if not (drop_db() and init_db()):
                        # Not inside an except clause: there is no active
                        # exception, so no exc_info is attached.
                        logger.critical("Unexpected error while dropping "
                                        "and recreating the database.")
                        return False
                except Exception:
                    logger.critical("Unable to access DB.", exc_info=True)
                    return False

            with SessionGen() as session:

                # Import the contest in JSON format.
                if self.load_model:
                    logger.info("Importing the contest from a JSON file.")

                    with io.open(os.path.join(self.import_dir,
                                              "contest.json"), "rb") as fin:
                        # TODO - Throughout all the code we'll assume the
                        # input is correct without actually doing any
                        # validations.  Thus, for example, we're not
                        # checking that the decoded object is a dict...
                        self.datas = json.load(fin)

                    # If the dump has been exported using a data model
                    # different than the current one (that is, a previous
                    # one) we try to update it.
                    # If no "_version" field is found we assume it's a v1.0
                    # export (before the new dump format was introduced).
                    dump_version = self.datas.get("_version", 0)

                    if dump_version < model_version:
                        logger.warning(
                            "The dump you're trying to import has been created "
                            "by an old version of CMS (it declares data model "
                            "version %d). It may take a while to adapt it to "
                            "the current data model (which is version %d). You "
                            "can use cmsDumpUpdater to update the on-disk dump "
                            "and speed up future imports.",
                            dump_version, model_version)

                    elif dump_version > model_version:
                        logger.critical(
                            "The dump you're trying to import has been created "
                            "by a version of CMS newer than this one (it "
                            "declares data model version %d) and there is no "
                            "way to adapt it to the current data model (which "
                            "is version %d). You probably need to update CMS to "
                            "handle it. It is impossible to proceed with the "
                            "importation.", dump_version, model_version)
                        return False

                    else:
                        logger.info(
                            "Importing dump with data model version %d.",
                            dump_version)

                    for version in range(dump_version, model_version):
                        # Update from version to version+1
                        updater = __import__(
                            "cmscontrib.updaters.update_%d" % (version + 1),
                            globals(), locals(), ["Updater"]).Updater(self.datas)
                        self.datas = updater.run()
                        self.datas["_version"] = version + 1

                    assert self.datas["_version"] == model_version

                    self.objs = dict()
                    for id_, data in iteritems(self.datas):
                        if not id_.startswith("_"):
                            self.objs[id_] = self.import_object(data)
                    for id_, data in iteritems(self.datas):
                        if not id_.startswith("_"):
                            self.add_relationships(data, self.objs[id_])

                    # Collect the object types the user asked to skip and
                    # drop them in one pass. isinstance() with an empty
                    # tuple is always False, so nothing is dropped when no
                    # skip flag is set, and an object can never be deleted
                    # twice.
                    skipped_types = ()
                    if self.skip_submissions:
                        skipped_types += (Submission,)
                    if self.skip_user_tests:
                        skipped_types += (UserTest,)
                    if self.skip_print_jobs:
                        skipped_types += (PrintJob,)
                    if self.skip_generated:
                        skipped_types += (SubmissionResult, UserTestResult)
                    for k, v in list(iteritems(self.objs)):
                        if isinstance(v, skipped_types):
                            del self.objs[k]

                    contest_id = list()
                    contest_files = set()

                    # We add explicitly only the top-level objects:
                    # contests, and tasks and users not contained in any
                    # contest. This will add on cascade all dependent
                    # objects, and not add orphaned objects (like those
                    # that depended on submissions or user tests that we
                    # might have removed above).
                    for id_ in self.datas["_objects"]:
                        obj = self.objs[id_]
                        session.add(obj)
                        session.flush()

                        if isinstance(obj, Contest):
                            contest_id += [obj.id]
                            contest_files |= enumerate_files(
                                session, obj,
                                skip_submissions=self.skip_submissions,
                                skip_user_tests=self.skip_user_tests,
                                skip_print_jobs=self.skip_print_jobs,
                                skip_generated=self.skip_generated)

                    session.commit()
                else:
                    contest_id = None
                    contest_files = None

                # Import files.
                if self.load_files:
                    logger.info("Importing files.")

                    files_dir = os.path.join(self.import_dir, "files")
                    descr_dir = os.path.join(self.import_dir, "descriptions")

                    files = set(os.listdir(files_dir))
                    descr = set(os.listdir(descr_dir))

                    if not descr <= files:
                        logger.warning("Some files do not have an associated "
                                       "description.")
                    if not files <= descr:
                        logger.warning("Some descriptions do not have an "
                                       "associated file.")

                    if not (contest_files is None or files <= contest_files):
                        # FIXME Check if it's because this is a light import
                        # or because we're skipping submissions or user_tests
                        logger.warning("The dump contains some files that are "
                                       "not needed by the contest.")
                    if not (contest_files is None or contest_files <= files):
                        # The reason for this could be that it was a light
                        # export that's not being reimported as such.
                        logger.warning("The contest needs some files that are "
                                       "not contained in the dump.")

                    # Limit import to files we actually need.
                    if contest_files is not None:
                        files &= contest_files

                    for digest in files:
                        file_ = os.path.join(files_dir, digest)
                        desc = os.path.join(descr_dir, digest)
                        if not self.safe_put_file(file_, desc):
                            logger.critical("Unable to put file `%s' in the DB. "
                                            "Aborting. Please remove the contest "
                                            "from the database.", file_)
                            # TODO: remove contest from the database.
                            return False
        finally:
            # Clean up, if an archive was used (on success and failure).
            if archive is not None:
                archive.cleanup()

        if contest_id is not None:
            logger.info("Import finished (contest id: %s).",
                        ", ".join("%d" % id_ for id_ in contest_id))
        else:
            logger.info("Import finished.")

        return True
Example #7
0
    def do_import(self):
        """Run the actual import code.

        If ``self.import_source`` is not a directory it must be a .zip,
        .tar, .tar.gz/.tgz, .tar.bz2/.tbz2 or .tar.xz/.txz archive. It is
        extracted into a fresh temporary directory, ``self.import_dir`` is
        pointed at the archive's root folder, and the temporary directory
        is removed before returning, whether the import succeeded or not.

        Returns:
            True on success; False on failure, after logging the reason
            at critical level.
        """
        logger.info("Starting import.")

        # Temporary directory an archive was extracted into, if any.
        temp_dir = None
        try:
            if not os.path.isdir(self.import_source):
                source = self.import_source
                # tarfile does not close a file object it was handed, so
                # keep a reference and close it ourselves.
                raw_file = None
                if source.endswith(".zip"):
                    archive = zipfile.ZipFile(source, "r")
                    # namelist() yields plain names, like getnames() does
                    # for tar archives (infolist() yields ZipInfo objects).
                    file_names = archive.namelist()
                elif source.endswith((".tar.gz", ".tgz", ".tar.bz2",
                                      ".tbz2", ".tar")):
                    archive = tarfile.open(name=source)
                    file_names = archive.getnames()
                elif source.endswith((".tar.xz", ".txz")):
                    try:
                        import lzma
                    except ImportError:
                        logger.critical("LZMA compression format not "
                                        "supported. Please install package "
                                        "lzma.")
                        return False
                    raw_file = lzma.LZMAFile(source)
                    archive = tarfile.open(fileobj=raw_file)
                    file_names = archive.getnames()
                else:
                    logger.critical("Unable to import from %s.", source)
                    return False

                try:
                    root = find_root_of_archive(file_names)
                    if root is None:
                        logger.critical("Cannot find a root directory in %s.",
                                        source)
                        return False

                    # Extract exactly once, into a single temporary
                    # directory that the outer finally clause removes.
                    # NOTE(review): extractall() trusts member paths; a
                    # crafted archive could write outside temp_dir.
                    temp_dir = tempfile.mkdtemp()
                    archive.extractall(temp_dir)
                finally:
                    archive.close()
                    if raw_file is not None:
                        raw_file.close()

                self.import_dir = os.path.join(temp_dir, root)

            if self.drop:
                logger.info("Dropping and recreating the database.")
                try:
                    if not (drop_db() and init_db()):
                        # Not inside an except clause: there is no active
                        # exception, so no exc_info is attached.
                        logger.critical("Unexpected error while dropping "
                                        "and recreating the database.")
                        return False
                except Exception:
                    logger.critical("Unable to access DB.", exc_info=True)
                    return False

            with SessionGen() as session:

                # Import the contest in JSON format.
                if self.load_model:
                    logger.info("Importing the contest from a JSON file.")

                    with io.open(os.path.join(self.import_dir,
                                              "contest.json"), "rb") as fin:
                        # TODO - Throughout all the code we'll assume the
                        # input is correct without actually doing any
                        # validations.  Thus, for example, we're not
                        # checking that the decoded object is a dict...
                        # (No ``encoding`` argument: UTF-8 is the default,
                        # and the keyword was removed in Python 3.9.)
                        self.datas = json.load(fin)

                    # If the dump has been exported using a data model
                    # different than the current one (that is, a previous
                    # one) we try to update it.
                    # If no "_version" field is found we assume it's a v1.0
                    # export (before the new dump format was introduced).
                    dump_version = self.datas.get("_version", 0)

                    if dump_version < model_version:
                        logger.warning(
                            "The dump you're trying to import has been created "
                            "by an old version of CMS. It may take a while to "
                            "adapt it to the current data model. You can use "
                            "cmsDumpUpdater to update the on-disk dump and "
                            "speed up future imports.")

                    if dump_version > model_version:
                        logger.critical(
                            "The dump you're trying to import has been created "
                            "by a version of CMS newer than this one and there "
                            "is no way to adapt it to the current data model. "
                            "You probably need to update CMS to handle it. It's "
                            "impossible to proceed with the importation.")
                        return False

                    for version in range(dump_version, model_version):
                        # Update from version to version+1
                        updater = __import__(
                            "cmscontrib.updaters.update_%d" % (version + 1),
                            globals(), locals(), ["Updater"]).Updater(self.datas)
                        self.datas = updater.run()
                        self.datas["_version"] = version + 1

                    assert self.datas["_version"] == model_version

                    # .items() works on both Python 2 and 3.
                    self.objs = dict()
                    for id_, data in self.datas.items():
                        if not id_.startswith("_"):
                            self.objs[id_] = self.import_object(data)
                    for id_, data in self.datas.items():
                        if not id_.startswith("_"):
                            self.add_relationships(data, self.objs[id_])

                    # Object types the user asked to skip, removed in one
                    # pass (isinstance() with an empty tuple is False).
                    skipped_types = ()
                    if self.skip_submissions:
                        skipped_types += (Submission,)
                    if self.skip_user_tests:
                        skipped_types += (UserTest,)
                    if self.skip_generated:
                        skipped_types += (SubmissionResult, UserTestResult)
                    for k, v in list(self.objs.items()):
                        if isinstance(v, skipped_types):
                            del self.objs[k]

                    contest_id = list()
                    contest_files = set()

                    # Add each base object and all its dependencies
                    for id_ in self.datas["_objects"]:
                        contest = self.objs[id_]

                        # We explictly add only the contest since all child
                        # objects will be automatically added by cascade.
                        # Adding each object individually would also add
                        # orphaned objects like the ones that depended on
                        # submissions or user_tests that we (possibly)
                        # removed above.
                        session.add(contest)
                        session.flush()

                        contest_id += [contest.id]
                        contest_files |= contest.enumerate_files(
                            self.skip_submissions, self.skip_user_tests,
                            self.skip_generated)

                    session.commit()
                else:
                    contest_id = None
                    contest_files = None

                # Import files.
                if self.load_files:
                    logger.info("Importing files.")

                    files_dir = os.path.join(self.import_dir, "files")
                    descr_dir = os.path.join(self.import_dir, "descriptions")

                    files = set(os.listdir(files_dir))
                    descr = set(os.listdir(descr_dir))

                    if not descr <= files:
                        logger.warning("Some files do not have an associated "
                                       "description.")
                    if not files <= descr:
                        logger.warning("Some descriptions do not have an "
                                       "associated file.")

                    if not (contest_files is None or files <= contest_files):
                        # FIXME Check if it's because this is a light import
                        # or because we're skipping submissions or user_tests
                        logger.warning("The dump contains some files that are "
                                       "not needed by the contest.")
                    if not (contest_files is None or contest_files <= files):
                        # The reason for this could be that it was a light
                        # export that's not being reimported as such.
                        logger.warning("The contest needs some files that are "
                                       "not contained in the dump.")

                    # Limit import to files we actually need.
                    if contest_files is not None:
                        files &= contest_files

                    for digest in files:
                        file_ = os.path.join(files_dir, digest)
                        desc = os.path.join(descr_dir, digest)
                        if not self.safe_put_file(file_, desc):
                            logger.critical("Unable to put file `%s' in the database. "
                                            "Aborting. Please remove the contest "
                                            "from the database.", file_)
                            # TODO: remove contest from the database.
                            return False
        finally:
            # Remove the whole temporary directory (not just the root
            # folder inside it) on every exit path. Only a directory we
            # created ourselves is ever removed.
            if temp_dir is not None:
                rmtree(temp_dir)

        if contest_id is not None:
            logger.info("Import finished (contest id: %s).",
                        ", ".join(str(id_) for id_ in contest_id))
        else:
            logger.info("Import finished.")

        return True
Example #8
0
    def do_import(self):
        """Run the actual import code.

        Unpacks ``self.import_source`` first when ``Archive`` supports it
        (its single top-level directory becomes ``self.import_dir``), then
        optionally drops and recreates the database (``self.drop``), loads
        the contest model from ``contest.json`` (``self.load_model``) and
        stores the dump's files (``self.load_files``).

        Returns:
            True on success; False on failure, after logging the reason
            at critical level.
        """
        logger.info("Starting import.")

        archive = None
        if Archive.is_supported(self.import_source):
            archive = Archive(self.import_source)
            self.import_dir = archive.unpack()

        # try/finally guarantees the unpacked archive is cleaned up on
        # every exit path, not only when the import succeeds.
        try:
            if archive is not None:
                file_names = os.listdir(self.import_dir)
                if len(file_names) != 1:
                    logger.critical("Cannot find a root directory in %s.",
                                    self.import_source)
                    return False

                self.import_dir = os.path.join(self.import_dir,
                                               file_names[0])

            if self.drop:
                logger.info("Dropping and recreating the database.")
                try:
                    if not (drop_db() and init_db()):
                        # No exception is being handled here, so there is
                        # no traceback to attach via exc_info.
                        logger.critical(
                            "Unexpected error while dropping "
                            "and recreating the database.")
                        return False
                except Exception:
                    logger.critical("Unable to access DB.", exc_info=True)
                    return False

            with SessionGen() as session:

                # Import the contest in JSON format.
                if self.load_model:
                    logger.info("Importing the contest from a JSON file.")

                    with io.open(os.path.join(self.import_dir, "contest.json"),
                                 "rb") as fin:
                        # TODO - Throughout all the code we'll assume the
                        # input is correct without actually doing any
                        # validations.  Thus, for example, we're not
                        # checking that the decoded object is a dict...
                        self.datas = json.load(fin)

                    # If the dump has been exported using a data model
                    # different than the current one (that is, a previous
                    # one) we try to update it.
                    # If no "_version" field is found we assume it's a v1.0
                    # export (before the new dump format was introduced).
                    dump_version = self.datas.get("_version", 0)

                    if dump_version < model_version:
                        logger.warning(
                            "The dump you're trying to import has been created "
                            "by an old version of CMS (it declares data model "
                            "version %d). It may take a while to adapt it to "
                            "the current data model (which is version %d). You "
                            "can use cmsDumpUpdater to update the on-disk dump "
                            "and speed up future imports.", dump_version,
                            model_version)

                    elif dump_version > model_version:
                        logger.critical(
                            "The dump you're trying to import has been created "
                            "by a version of CMS newer than this one (it "
                            "declares data model version %d) and there is no "
                            "way to adapt it to the current data model (which "
                            "is version %d). You probably need to update CMS to "
                            "handle it. It is impossible to proceed with the "
                            "importation.", dump_version, model_version)
                        return False

                    else:
                        logger.info("Importing dump with data model version %d.",
                                    dump_version)

                    for version in range(dump_version, model_version):
                        # Update from version to version+1
                        updater = __import__(
                            "cmscontrib.updaters.update_%d" % (version + 1),
                            globals(), locals(), ["Updater"]).Updater(self.datas)
                        self.datas = updater.run()
                        self.datas["_version"] = version + 1

                    assert self.datas["_version"] == model_version

                    self.objs = dict()
                    for id_, data in iteritems(self.datas):
                        if not id_.startswith("_"):
                            self.objs[id_] = self.import_object(data)
                    for id_, data in iteritems(self.datas):
                        if not id_.startswith("_"):
                            self.add_relationships(data, self.objs[id_])

                    # Build the tuple of object types to skip and filter
                    # in a single pass; isinstance() with an empty tuple
                    # is always False, and no object is deleted twice.
                    skipped_types = ()
                    if self.skip_submissions:
                        skipped_types += (Submission,)
                    if self.skip_user_tests:
                        skipped_types += (UserTest,)
                    if self.skip_print_jobs:
                        skipped_types += (PrintJob,)
                    if self.skip_generated:
                        skipped_types += (SubmissionResult, UserTestResult)
                    for k, v in list(iteritems(self.objs)):
                        if isinstance(v, skipped_types):
                            del self.objs[k]

                    contest_id = list()
                    contest_files = set()

                    # We add explicitly only the top-level objects:
                    # contests, and tasks and users not contained in any
                    # contest. This will add on cascade all dependent
                    # objects, and not add orphaned objects (like those
                    # that depended on submissions or user tests that we
                    # might have removed above).
                    for id_ in self.datas["_objects"]:
                        obj = self.objs[id_]
                        session.add(obj)
                        session.flush()

                        if isinstance(obj, Contest):
                            contest_id += [obj.id]
                            contest_files |= enumerate_files(
                                session,
                                obj,
                                skip_submissions=self.skip_submissions,
                                skip_user_tests=self.skip_user_tests,
                                skip_print_jobs=self.skip_print_jobs,
                                skip_generated=self.skip_generated)

                    session.commit()
                else:
                    contest_id = None
                    contest_files = None

                # Import files.
                if self.load_files:
                    logger.info("Importing files.")

                    files_dir = os.path.join(self.import_dir, "files")
                    descr_dir = os.path.join(self.import_dir, "descriptions")

                    files = set(os.listdir(files_dir))
                    descr = set(os.listdir(descr_dir))

                    if not descr <= files:
                        logger.warning("Some files do not have an associated "
                                       "description.")
                    if not files <= descr:
                        logger.warning("Some descriptions do not have an "
                                       "associated file.")

                    if not (contest_files is None or files <= contest_files):
                        # FIXME Check if it's because this is a light import
                        # or because we're skipping submissions or user_tests
                        logger.warning("The dump contains some files that are "
                                       "not needed by the contest.")
                    if not (contest_files is None or contest_files <= files):
                        # The reason for this could be that it was a light
                        # export that's not being reimported as such.
                        logger.warning("The contest needs some files that are "
                                       "not contained in the dump.")

                    # Limit import to files we actually need.
                    if contest_files is not None:
                        files &= contest_files

                    for digest in files:
                        file_ = os.path.join(files_dir, digest)
                        desc = os.path.join(descr_dir, digest)
                        if not self.safe_put_file(file_, desc):
                            logger.critical(
                                "Unable to put file `%s' in the DB. "
                                "Aborting. Please remove the contest "
                                "from the database.", file_)
                            # TODO: remove contest from the database.
                            return False
        finally:
            # Clean up, if an archive was used (success or failure).
            if archive is not None:
                archive.cleanup()

        if contest_id is not None:
            logger.info("Import finished (contest id: %s).",
                        ", ".join("%d" % id_ for id_ in contest_id))
        else:
            logger.info("Import finished.")

        return True
Example #9
0
 def tearDownClass(cls):
     """Drop the test database, then release the shared connection/engine.

     The connection is closed and the engine disposed even when
     drop_db() raises, so a failing teardown does not leak them.
     """
     try:
         drop_db()
     finally:
         cls.connection.close()
         cms.db.engine.dispose()
Example #10
0
 def setUpClass(cls):
     """Run the parent class setup, then reset the test database.

     Raises:
         AssertionError: if the engine's connection string does not
             contain "fortesting"; checked before drop_db() runs.
     """
     super(DatabaseMixin, cls).setUpClass()
     # Explicit raise instead of ``assert`` so the guard in front of
     # drop_db() survives ``python -O``.
     if "fortesting" not in str(engine):
         raise AssertionError(
             "Monkey patching of DB connection string failed")
     drop_db()
     init_db()
Example #11
0
 def tearDownClass(cls):
     """Drop the test database, release DB resources, run parent teardown.

     Each later step still runs if an earlier one raises: the connection
     is closed and the engine disposed even if drop_db() fails, and the
     parent class teardown always runs last.
     """
     try:
         drop_db()
     finally:
         try:
             cls.connection.close()
             cms.db.engine.dispose()
         finally:
             super(DatabaseMixin, cls).tearDownClass()
Example #12
0
 def tearDownClass(cls):
     """Drop the test database, then run the parent class teardown.

     The parent teardown runs even if drop_db() raises.
     """
     try:
         drop_db()
     finally:
         super().tearDownClass()
Example #13
0
 def setUpClass(cls):
     """Run the parent setup, then drop and recreate the test database.

     Raises:
         AssertionError: when the engine URL lacks the "fortesting"
             marker; raised before anything is dropped.
     """
     super().setUpClass()
     # Not an ``assert`` statement: asserts vanish under ``python -O``,
     # which would leave drop_db() unguarded.
     if "fortesting" not in str(engine):
         raise AssertionError(
             "Monkey patching of DB connection string failed")
     drop_db()
     init_db()
Example #14
0
    def do_import(self):
        """Run the actual import code."""
        logger.info("Starting import.")

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                file_names = archive.infolist()

                self.import_dir = tempfile.mkdtemp()
                archive.extractall(self.import_dir)
            elif self.import_source.endswith(".tar.gz") \
                     or self.import_source.endswith(".tgz") \
                     or self.import_source.endswith(".tar.bz2") \
                     or self.import_source.endswith(".tbz2") \
                     or self.import_source.endswith(".tar"):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            elif self.import_source.endswith(".tar.xz") \
                    or self.import_source.endswith(".txz"):
                try:
                    import lzma
                except ImportError:
                    logger.critical("LZMA compression format not "
                                    "supported. Please install package "
                                    "lzma.")
                    return False
                archive = tarfile.open(
                    fileobj=lzma.LZMAFile(self.import_source))
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            root = find_root_of_archive(file_names)
            if root is None:
                logger.critical("Cannot find a root directory in %s." %
                                self.import_source)
                return False

            self.import_dir = tempfile.mkdtemp()
            archive.extractall(self.import_dir)
            self.import_dir = os.path.join(self.import_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                if not (drop_db() and init_db()):
                    logger.critical(
                        "Unexpected error while dropping "
                        "and recreating the database.",
                        exc_info=True)
                    return False
            except Exception as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False

        with SessionGen() as session:

            # Import the contest in JSON format.
            if self.load_model:
                logger.info("Importing the contest from a JSON file.")

                with io.open(os.path.join(self.import_dir, "contest.json"),
                             "rb") as fin:
                    # TODO - Throughout all the code we'll assume the
                    # input is correct without actually doing any
                    # validations.  Thus, for example, we're not
                    # checking that the decoded object is a dict...
                    self.datas = json.load(fin, encoding="utf-8")

                # If the dump has been exported using a data model
                # different than the current one (that is, a previous
                # one) we try to update it.
                # If no "_version" field is found we assume it's a v1.0
                # export (before the new dump format was introduced).
                dump_version = self.datas.get("_version", 0)

                if dump_version < model_version:
                    logger.warning(
                        "The dump you're trying to import has been created "
                        "by an old version of CMS. It may take a while to "
                        "adapt it to the current data model. You can use "
                        "cmsDumpUpdater to update the on-disk dump and "
                        "speed up future imports.")

                if dump_version > model_version:
                    logger.critical(
                        "The dump you're trying to import has been created "
                        "by a version of CMS newer than this one and there "
                        "is no way to adapt it to the current data model. "
                        "You probably need to update CMS to handle it. It's "
                        "impossible to proceed with the importation.")
                    return False

                for version in range(dump_version, model_version):
                    # Update from version to version+1
                    updater = __import__(
                        "cmscontrib.updaters.update_%d" % (version + 1),
                        globals(), locals(), ["Updater"]).Updater(self.datas)
                    self.datas = updater.run()
                    self.datas["_version"] = version + 1

                assert self.datas["_version"] == model_version

                self.objs = dict()
                for id_, data in self.datas.iteritems():
                    if not id_.startswith("_"):
                        self.objs[id_] = self.import_object(data)
                for id_, data in self.datas.iteritems():
                    if not id_.startswith("_"):
                        self.add_relationships(data, self.objs[id_])

                for k, v in list(self.objs.iteritems()):

                    # Skip submissions if requested
                    if self.skip_submissions and isinstance(v, Submission):
                        del self.objs[k]

                    # Skip user_tests if requested
                    if self.skip_user_tests and isinstance(v, UserTest):
                        del self.objs[k]

                    # Skip generated data if requested
                    if self.skip_generated and \
                            isinstance(v, (SubmissionResult, UserTestResult)):
                        del self.objs[k]

                contest_id = list()
                contest_files = set()

                # Add each base object and all its dependencies
                for id_ in self.datas["_objects"]:
                    contest = self.objs[id_]

                    # We explictly add only the contest since all child
                    # objects will be automatically added by cascade.
                    # Adding each object individually would also add
                    # orphaned objects like the ones that depended on
                    # submissions or user_tests that we (possibly)
                    # removed above.
                    session.add(contest)
                    session.flush()

                    contest_id += [contest.id]
                    contest_files |= contest.enumerate_files(
                        self.skip_submissions, self.skip_user_tests,
                        self.skip_generated)

                session.commit()
            else:
                contest_id = None
                contest_files = None

            # Import files.
            if self.load_files:
                logger.info("Importing files.")

                files_dir = os.path.join(self.import_dir, "files")
                descr_dir = os.path.join(self.import_dir, "descriptions")

                files = set(os.listdir(files_dir))
                descr = set(os.listdir(descr_dir))

                if not descr <= files:
                    logger.warning("Some files do not have an associated "
                                   "description.")
                if not files <= descr:
                    logger.warning("Some descriptions do not have an "
                                   "associated file.")

                if not (contest_files is None or files <= contest_files):
                    # FIXME Check if it's because this is a light import
                    # or because we're skipping submissions or user_tests
                    logger.warning("The dump contains some files that are "
                                   "not needed by the contest.")
                if not (contest_files is None or contest_files <= files):
                    # The reason for this could be that it was a light
                    # export that's not being reimported as such.
                    logger.warning("The contest needs some files that are "
                                   "not contained in the dump.")

                # Limit import to files we actually need.
                if contest_files is not None:
                    files &= contest_files

                for digest in files:
                    file_ = os.path.join(files_dir, digest)
                    desc = os.path.join(descr_dir, digest)
                    if not self.safe_put_file(file_, desc):
                        logger.critical(
                            "Unable to put file `%s' in the database. "
                            "Aborting. Please remove the contest "
                            "from the database." % file_)
                        # TODO: remove contest from the database.
                        return False

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." %
                        ", ".join(str(id_) for id_ in contest_id))
        else:
            logger.info("Import finished.")

        # If we extracted an archive, we remove it.
        if self.import_dir != self.import_source:
            rmtree(self.import_dir)

        return True
Exemple #15
0
 def tearDownClass(cls):
     """Drop the test database, then run the parent class teardown.

     The parent teardown is invoked from a ``finally`` clause so that it
     still runs when ``drop_db`` raises; otherwise a failure while
     dropping would also skip the superclass cleanup.
     """
     try:
         drop_db()
     finally:
         super(DatabaseMixin, cls).tearDownClass()
Exemple #16
0
 def tearDownClass(cls):
     """Drop the test database and release the class-level DB resources.

     ``cls.connection`` (presumably opened in ``setUpClass``) is closed and
     ``cms.db.engine`` is disposed even when ``drop_db`` raises, so a failed
     drop does not leave connections open for the rest of the test run.
     The nested ``finally`` makes sure the engine is disposed even if
     closing the connection itself fails.
     """
     try:
         drop_db()
     finally:
         try:
             cls.connection.close()
         finally:
             cms.db.engine.dispose()
Exemple #17
0
 def tearDownClass(cls):
     """Drop the test database, then run the parent class teardown.

     The parent teardown is invoked from a ``finally`` clause so that it
     still runs when ``drop_db`` raises; otherwise a failure while
     dropping would also skip the superclass cleanup.
     """
     try:
         drop_db()
     finally:
         super(DatabaseMixin, cls).tearDownClass()
Exemple #18
0
 def setUpClass(cls):
     """Point ``cms.db`` at a fresh engine and rebuild the test schema.

     A new engine is created from ``cms.config.database`` and stored in
     ``cms.db.engine``; the database is then dropped and re-initialized,
     and a class-level connection is opened and used to create every
     table in ``cms.db.metadata``.

     raise (RuntimeError): if ``drop_db`` or ``init_db`` reports failure
         by returning a false value.
     """
     # NOTE(review): unlike other setUpClass variants in this codebase,
     # there is no check that the engine targets a testing database
     # before it gets dropped -- confirm the config is always patched.
     cms.db.engine = create_engine(cms.config.database)
     # drop_db()/init_db() can signal failure through a false return
     # value; refuse to run the tests on a half-reset database.
     if not drop_db():
         raise RuntimeError("Unable to drop the test database.")
     if not init_db():
         raise RuntimeError("Unable to initialize the test database.")
     cls.connection = cms.db.engine.connect()
     cms.db.metadata.create_all(cls.connection)