Example #1
0
    def do_import(self):
        """Prepare the database schema and store the loaded contest.

        When self.drop is set, FSObject.delete_all is run in its own
        session and all tables are dropped; the tables are then
        (re)created.  The data returned by self.loader for self.path
        is turned into a Contest, added to a session and committed,
        after which the tables are analyzed.

        return (bool): False if the database raised an
                       OperationalError while dropping or creating
                       the tables, True once the contest is stored.

        """
        logger.info("Creating database structure.")
        must_drop = self.drop
        # Dropping and creating share the very same failure handling,
        # so a single try block covers both steps.
        try:
            if must_drop:
                with SessionGen() as session:
                    FSObject.delete_all(session)
                    session.commit()
                metadata.drop_all()
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        contest_dict = self.loader.import_contest(self.path)
        contest = Contest.import_from_dict(contest_dict)

        logger.info("Creating contest on the database.")
        with SessionGen() as session:
            session.add(contest)
            logger.info("Analyzing database.")
            session.commit()
            # Read the id after the commit, while the session is
            # still open.
            new_id = contest.id
            analyze_all_tables(session)

        logger.info("Import finished (new contest id: %s)." % new_id)
        return True
Example #2
0
def ask_for_contest(skip=None):
    """Print a greeter that ask the user for a contest, if there is
    not an indication of which contest to use in the command line.

    skip (int/None): how many commandline arguments are already taken
                     by other usages.

    return (int): a contest_id.

    """
    # It seems nice to create the tables first if they do not exists.
    try:
        metadata.create_all()
    except sqlalchemy.exc.OperationalError as error:
        print "Unable to create DB.\n" \
              "%r" % error
        raise sqlalchemy.exc.OperationalError()

    if isinstance(skip, int) and len(sys.argv) > skip + 1:
        contest_id = int(sys.argv[skip + 1])

    else:

        with SessionGen(commit=False) as session:
            contests = get_contest_list(session)
            # The ids of the contests are cached, so the session can
            # be closed as soon as possible
            matches = {}
            n_contests = len(contests)
            if n_contests == 0:
                print "No contests in the database."
                print "You may want to use some of the facilities in " \
                      "cmscontrib to import a contest."
                sys.exit(0)
            print "Contests available:"
            for i, row in enumerate(contests):
                print "%3d  -  ID: %d  -  Name: %s  -  Description: %s" % \
                      (i + 1, row.id, row.name, row.description),
                matches[i + 1] = row.id
                if i == n_contests - 1:
                    print " (default)"
                else:
                    print

        contest_number = raw_input("Insert the row number next to the contest "
                                   "you want to load (not the id): ")
        if contest_number == "":
            contest_number = n_contests
        try:
            contest_id = matches[int(contest_number)]
        except (ValueError, KeyError):
            print "Insert a correct number."
            sys.exit(1)

    return contest_id
Example #3
0
 def _prepare_db(self):
     """Create the database tables, dropping everything first when
     self.drop is set.

     return (bool): True on success, False if the database raised
                    an OperationalError.

     """
     logger.info("Creating database structure.")
     must_drop = self.drop
     # Both steps fail in the same way, so one handler serves both.
     try:
         if must_drop:
             drop_everything()
         metadata.create_all()
     except sqlalchemy.exc.OperationalError as error:
         logger.critical("Unable to access DB.\n%r" % error)
         return False
     else:
         return True
Example #4
0
 def _prepare_db(self):
     """Create the database tables, wiping the existing ones first
     when self.drop is set.

     When dropping, FSObject.delete_all is run in its own session
     and committed before the tables are dropped.

     return (bool): True on success, False if the database raised
                    an OperationalError.

     """
     logger.info("Creating database structure.")
     if self.drop:
         try:
             with SessionGen() as session:
                 FSObject.delete_all(session)
                 session.commit()
             metadata.drop_all()
         except sqlalchemy.exc.OperationalError as error:
             logger.critical("Unable to access DB.\n%r" % error)
             return False
     try:
         metadata.create_all()
     except sqlalchemy.exc.OperationalError as error:
         logger.critical("Unable to access DB.\n%r" % error)
         return False
     # Report success explicitly: falling off the end would return
     # None, which a caller testing the result reads as a failure.
     return True
    def do_import(self):
        """Prepare the database schema and store the loaded contest.

        When self.drop is set, FSObject.delete_all is run in its own
        session and all tables are dropped; the tables are then
        (re)created.  The data returned by self.loader for self.path
        is turned into a Contest, added to a session and committed,
        after which the tables are analyzed.

        return (bool): False if the database raised an
                       OperationalError while dropping or creating
                       the tables, True once the contest is stored.

        """
        logger.info("Creating database structure.")
        must_drop = self.drop
        # Dropping and creating share the very same failure handling,
        # so a single try block covers both steps.
        try:
            if must_drop:
                with SessionGen() as session:
                    FSObject.delete_all(session)
                    session.commit()
                metadata.drop_all()
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        contest_dict = self.loader.import_contest(self.path)
        contest = Contest.import_from_dict(contest_dict)

        logger.info("Creating contest on the database.")
        with SessionGen() as session:
            session.add(contest)
            logger.info("Analyzing database.")
            session.commit()
            # Read the id after the commit, while the session is
            # still open.
            new_id = contest.id
            analyze_all_tables(session)

        logger.info("Import finished (new contest id: %s)." % new_id)
        return True
Example #6
0
    def do_import(self):
        """Run the actual import code.

        If self.import_source is not a directory it is treated as an
        archive (zip, tar, optionally gzip/bzip2/xz compressed) and
        extracted to a temporary directory.  The database tables are
        created (after being dropped if self.drop is set), then,
        depending on self.load_model and self.load_files, the objects
        described in contest.json and the files of the dump are
        imported.

        return (bool): True if the import completed, False if it was
                       aborted (the reason is logged as critical).

        """
        logger.operation = "importing contest from %s" % self.import_source
        logger.info("Starting import.")

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                # namelist() yields plain member names, the same kind
                # of value the tar branches obtain from getnames().
                # The extraction itself happens once, further below.
                file_names = archive.namelist()
            elif self.import_source.endswith(".tar.gz") \
                     or self.import_source.endswith(".tgz") \
                     or self.import_source.endswith(".tar.bz2") \
                     or self.import_source.endswith(".tbz2") \
                     or self.import_source.endswith(".tar"):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            elif self.import_source.endswith(".tar.xz") \
                    or self.import_source.endswith(".txz"):
                # lzma is an optional dependency, hence the late import.
                try:
                    import lzma
                except ImportError:
                    logger.critical("LZMA compression format not "
                                    "supported. Please install package "
                                    "lzma.")
                    return False
                archive = tarfile.open(
                    fileobj=lzma.LZMAFile(self.import_source))
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            root = find_root_of_archive(file_names)
            if root is None:
                logger.critical("Cannot find a root directory in %s." %
                                self.import_source)
                return False

            self.import_dir = tempfile.mkdtemp()
            archive.extractall(self.import_dir)
            self.import_dir = os.path.join(self.import_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                metadata.drop_all()
            except sqlalchemy.exc.OperationalError as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False
        try:
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        with SessionGen(commit=False) as session:

            # Import the contest in JSON format.
            if self.load_model:
                logger.info("Importing the contest from a JSON file.")

                with io.open(os.path.join(self.import_dir,
                                          "contest.json"), "rb") as fin:
                    # TODO - Throughout all the code we'll assume the
                    # input is correct without actually doing any
                    # validations.  Thus, for example, we're not
                    # checking that the decoded object is a dict...
                    self.datas = json.load(fin, encoding="utf-8")

                # If the dump has been exported using a data model
                # different than the current one (that is, a previous
                # one) we try to update it.
                # If no "_version" field is found we assume it's a v1.0
                # export (before the new dump format was introduced).
                dump_version = self.datas.get("_version", 0)

                if dump_version < model_version:
                    logger.warning(
                        "The dump you're trying to import has been created "
                        "by an old version of CMS. It may take a while to "
                        "adapt it to the current data model. You can use "
                        "cmsDumpUpdater to update the on-disk dump and "
                        "speed up future imports.")

                if dump_version > model_version:
                    logger.critical(
                        "The dump you're trying to import has been created "
                        "by a version of CMS newer than this one and there "
                        "is no way to adapt it to the current data model. "
                        "You probably need to update CMS to handle it. It's "
                        "impossible to proceed with the importation.")
                    return False

                for version in range(dump_version, model_version):
                    # Update from version to version+1
                    updater = __import__(
                        "cmscontrib.updaters.update_%d" % (version + 1),
                        globals(), locals(), ["Updater"]).Updater(self.datas)
                    self.datas = updater.run()
                    self.datas["_version"] = version + 1

                assert self.datas["_version"] == model_version

                # Two passes: first build every object, then link them,
                # so that relationships can refer to any object.
                self.objs = dict()
                for id_, data in self.datas.iteritems():
                    if not id_.startswith("_"):
                        self.objs[id_] = self.import_object(data)
                for id_, data in self.datas.iteritems():
                    if not id_.startswith("_"):
                        self.add_relationships(data, self.objs[id_])

                # Iterate over a copy since entries are deleted.
                for k, v in list(self.objs.iteritems()):

                    # Skip submissions if requested
                    if self.skip_submissions and isinstance(v, Submission):
                        del self.objs[k]

                    # Skip user_tests if requested
                    if self.skip_user_tests and isinstance(v, UserTest):
                        del self.objs[k]

                contest_id = list()
                contest_files = set()

                # Add each base object and all its dependencies
                for id_ in self.datas["_objects"]:
                    contest = self.objs[id_]

                    # We explicitly add only the contest since all child
                    # objects will be automatically added by cascade.
                    # Adding each object individually would also add
                    # orphaned objects like the ones that depended on
                    # submissions or user_tests that we (possibly)
                    # removed above.
                    session.add(contest)
                    session.flush()

                    contest_id += [contest.id]
                    contest_files |= contest.enumerate_files(
                        self.skip_submissions, self.skip_user_tests, self.light)

                session.commit()
            else:
                # No model imported: nothing to report and no way to
                # know which files the contest needs.
                contest_id = None
                contest_files = None

            # Import files.
            if self.load_files:
                logger.info("Importing files.")

                files_dir = os.path.join(self.import_dir, "files")
                descr_dir = os.path.join(self.import_dir, "descriptions")

                files = set(os.listdir(files_dir))
                descr = set(os.listdir(descr_dir))

                if not descr <= files:
                    logger.warning("Some files do not have an associated "
                                   "description.")
                if not files <= descr:
                    logger.warning("Some descriptions do not have an "
                                   "associated file.")

                if not (contest_files is None or files <= contest_files):
                    # FIXME Check if it's because this is a light import
                    # or because we're skipping submissions or user_tests
                    logger.warning("The dump contains some files that are "
                                   "not needed by the contest.")
                if not (contest_files is None or contest_files <= files):
                    # The reason for this could be that it was a light
                    # export that's not being reimported as such.
                    logger.warning("The contest needs some files that are "
                                   "not contained in the dump.")

                # Limit import to files we actually need.
                if contest_files is not None:
                    files &= contest_files

                for digest in files:
                    file_ = os.path.join(files_dir, digest)
                    desc = os.path.join(descr_dir, digest)
                    if not self.safe_put_file(file_, desc):
                        logger.critical("Unable to put file `%s' in the database. "
                                        "Aborting. Please remove the contest "
                                        "from the database." % file_)
                        # TODO: remove contest from the database.
                        return False

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." %
                        ", ".join(str(id_) for id_ in contest_id))
        else:
            logger.info("Import finished.")
        logger.operation = ""

        # If we extracted an archive, we remove it.
        # NOTE: the early "return False" paths above skip this cleanup.
        if self.import_dir != self.import_source:
            rmtree(self.import_dir)

        return True
Example #7
0
    def do_import(self):
        """Run the actual import code.

        If self.import_source is not a directory it is treated as a
        zip or tar archive and extracted to a temporary directory.
        The database tables are created (after being dropped if
        self.drop is set), contest.json is read and, unless
        self.only_files is set, the contest is stored in the database.
        Unless self.no_files is set, the files of the dump are then
        imported.

        return (bool): True if the import completed, False if it was
                       aborted (the reason is logged as critical).

        """
        logger.operation = "importing contest from %s" % self.import_source
        logger.info("Starting import.")

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                # namelist() yields plain member names, the same kind
                # of value the tar branch obtains from getnames().
                # The extraction itself happens once, further below.
                file_names = archive.namelist()
            elif self.import_source.endswith(".tar.gz") \
                     or self.import_source.endswith(".tgz") \
                     or self.import_source.endswith(".tar.bz2") \
                     or self.import_source.endswith(".tbz2") \
                     or self.import_source.endswith(".tar"):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            root = find_root_of_archive(file_names)
            if root is None:
                logger.critical("Cannot find a root directory in %s." %
                                self.import_source)
                return False

            self.import_dir = tempfile.mkdtemp()
            archive.extractall(self.import_dir)
            self.import_dir = os.path.join(self.import_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                metadata.drop_all()
            except sqlalchemy.exc.OperationalError as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False
        try:
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        logger.info("Reading JSON file...")
        with open(os.path.join(self.import_dir, "contest.json")) as fin:
            contest_json = json.load(fin)
        if self.no_submissions:
            for user in contest_json["users"]:
                user["submissions"] = []
                user["user_tests"] = []

        # Both stay None when only the files are imported, so that the
        # code below never refers to unbound names.
        contest_id = None
        contest_files = None

        if not self.only_files:
            with SessionGen(commit=False) as session:

                # Import the contest in JSON format.
                logger.info("Importing the contest from JSON file.")
                contest = Contest.import_from_dict(contest_json)
                session.add(contest)

                # Flush so that the contest gets its id assigned.
                session.flush()
                contest_id = contest.id
                contest_files = contest.enumerate_files()
                session.commit()

        if not self.no_files:
            logger.info("Importing files.")
            files_dir = os.path.join(self.import_dir, "files")
            descr_dir = os.path.join(self.import_dir, "descriptions")
            if contest_files is None:
                # No contest was imported, so there is no list of
                # needed files: import every file found in the dump.
                contest_files = set(os.listdir(files_dir))
            for digest in contest_files:
                file_ = os.path.join(files_dir, digest)
                desc = os.path.join(descr_dir, digest)
                if not os.path.exists(file_) or not os.path.exists(desc):
                    logger.error("Some files needed to the contest "
                                 "are missing in the import directory. "
                                 "The import will continue. Be aware.")
                if not self.safe_put_file(file_, desc):
                    logger.critical("Unable to put file `%s' in the database. "
                                    "Aborting. Please remove the contest "
                                    "from the database." % file_)
                    # TODO: remove contest from the database.
                    return False

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." % contest_id)
        else:
            logger.info("Import finished.")
        logger.operation = ""

        # If we extracted an archive, we remove it.
        # NOTE: the early "return False" paths above skip this cleanup.
        if self.import_dir != self.import_source:
            shutil.rmtree(self.import_dir)

        return True
Example #8
0
    def do_import(self):
        """Run the actual import code.

        If self.import_source is not a directory it is treated as a
        zip or tar archive and extracted to a temporary directory.
        The database tables are created (after being dropped if
        self.drop is set).  Unless self.no_files is set, every file in
        the dump is imported; unless self.only_files is set, the
        contest described by contest.json is then stored in the
        database.

        return (bool): True if the import completed, False if it was
                       aborted.

        """
        logger.operation = "importing contest from %s" % self.import_source
        logger.info("Starting import.")

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                # namelist() yields plain member names, the same kind
                # of value the tar branch obtains from getnames().
                # The extraction itself happens once, further below.
                file_names = archive.namelist()
            elif self.import_source.endswith(".tar.gz") \
                     or self.import_source.endswith(".tgz") \
                     or self.import_source.endswith(".tar.bz2") \
                     or self.import_source.endswith(".tbz2") \
                     or self.import_source.endswith(".tar"):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            root = find_root_of_archive(file_names)
            if root is None:
                logger.critical("Cannot find a root directory in %s." %
                                self.import_source)
                return False

            self.import_dir = tempfile.mkdtemp()
            archive.extractall(self.import_dir)
            self.import_dir = os.path.join(self.import_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                metadata.drop_all()
            except sqlalchemy.exc.OperationalError as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False
        try:
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        if not self.no_files:
            logger.info("Importing files.")
            files_dir = os.path.join(self.import_dir, "files")
            descr_dir = os.path.join(self.import_dir, "descriptions")
            files = set(os.listdir(files_dir))
            for _file in files:
                if not self.safe_put_file(os.path.join(files_dir, _file),
                                          os.path.join(descr_dir, _file)):
                    return False

        # Stays None when only the files are imported, so that the
        # final log message never refers to an unbound name.
        contest_id = None

        if not self.only_files:
            with SessionGen(commit=False) as session:

                # Import the contest in JSON format.
                logger.info("Importing the contest from JSON file.")
                with open(os.path.join(self.import_dir,
                                       "contest.json")) as fin:
                    contest = Contest.import_from_dict(json.load(fin))
                    session.add(contest)

                # Check that no files were missing (only if files were
                # imported).
                if not self.no_files:
                    contest_files = contest.enumerate_files()
                    missing_files = contest_files.difference(files)
                    if len(missing_files) > 0:
                        logger.warning("Some files needed to the contest "
                                       "are missing in the import directory.")

                # Flush so that the contest gets its id assigned.
                session.flush()
                contest_id = contest.id
                session.commit()

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." % contest_id)
        else:
            logger.info("Import finished.")
        logger.operation = ""

        # If we extracted an archive, we remove it.
        # NOTE: the early "return False" paths above skip this cleanup.
        if self.import_dir != self.import_source:
            shutil.rmtree(self.import_dir)

        return True
Example #9
0
    def do_import(self):
        """Run the actual import code.

        If self.import_source is not a directory it is treated as a
        zip or tar archive and extracted to a temporary directory.
        The database tables are created (after being dropped if
        self.drop is set).  Unless self.only_files is set, the objects
        described in contest.json are built, linked and stored in the
        database; unless self.no_files is set, the files of the dump
        are then imported.

        return (bool): True if the import completed, False if it was
                       aborted (the reason is logged as critical).

        """
        logger.operation = "importing contest from %s" % self.import_source
        logger.info("Starting import.")

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                # namelist() yields plain member names, the same kind
                # of value the tar branch obtains from getnames().
                # The extraction itself happens once, further below.
                file_names = archive.namelist()
            elif self.import_source.endswith(".tar.gz") \
                     or self.import_source.endswith(".tgz") \
                     or self.import_source.endswith(".tar.bz2") \
                     or self.import_source.endswith(".tbz2") \
                     or self.import_source.endswith(".tar"):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            root = find_root_of_archive(file_names)
            if root is None:
                logger.critical("Cannot find a root directory in %s." %
                                self.import_source)
                return False

            self.import_dir = tempfile.mkdtemp()
            archive.extractall(self.import_dir)
            self.import_dir = os.path.join(self.import_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                metadata.drop_all()
            except sqlalchemy.exc.OperationalError as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False
        try:
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        # Both stay None when only the files are imported, so that the
        # code below never refers to unbound names.
        contest_id = None
        contest_files = None

        if not self.only_files:
            with SessionGen(commit=False) as session:

                # Import the contest in JSON format.
                logger.info("Importing the contest from JSON file.")

                with open(os.path.join(self.import_dir,
                                       "contest.json")) as fin:
                    # Throughout all the code we'll assume the input is
                    # correct without actually doing any validations.
                    # Thus, for example, we're not checking that the
                    # decoded object is a dict...
                    self.datas = json.load(fin)

                # Two passes: first build and add every object, then
                # link them, so relationships can refer to any object.
                self.objs = dict()
                for _id, data in self.datas.iteritems():
                    obj = self.import_object(data)
                    self.objs[_id] = obj
                    session.add(obj)

                for _id in self.datas:
                    self.add_relationships(self.datas[_id], self.objs[_id])

                # Mmh... kind of fragile interface
                contest = self.objs["0"]

                # Flush so that the contest gets its id assigned.
                session.flush()
                contest_id = contest.id
                contest_files = contest.enumerate_files()
                session.commit()

        if not self.no_files:
            logger.info("Importing files.")
            files_dir = os.path.join(self.import_dir, "files")
            descr_dir = os.path.join(self.import_dir, "descriptions")
            if contest_files is None:
                # No contest was imported, so there is no list of
                # needed files: import every file found in the dump.
                contest_files = set(os.listdir(files_dir))
            for digest in contest_files:
                file_ = os.path.join(files_dir, digest)
                desc = os.path.join(descr_dir, digest)
                if not os.path.exists(file_) or not os.path.exists(desc):
                    logger.error("Some files needed to the contest "
                                 "are missing in the import directory. "
                                 "The import will continue. Be aware.")
                if not self.safe_put_file(file_, desc):
                    logger.critical("Unable to put file `%s' in the database. "
                                    "Aborting. Please remove the contest "
                                    "from the database." % file_)
                    # TODO: remove contest from the database.
                    return False

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." % contest_id)
        else:
            logger.info("Import finished.")
        logger.operation = ""

        # If we extracted an archive, we remove it.
        # NOTE: the early "return False" paths above skip this cleanup.
        if self.import_dir != self.import_source:
            shutil.rmtree(self.import_dir)

        return True
Example #10
0
    def do_import(self):
        """Run the actual import code.

        If self.import_source is not a directory it is treated as a
        zip or tar archive and extracted to a temporary directory.
        The database tables are created (after being dropped if
        self.drop is set), contest.json is read and, unless
        self.only_files is set, the contest is stored in the database.
        Unless self.no_files is set, the files of the dump are then
        imported.

        return (bool): True if the import completed, False if it was
                       aborted (the reason is logged as critical).

        """
        logger.operation = "importing contest from %s" % self.import_source
        logger.info("Starting import.")

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                # namelist() yields plain member names, the same kind
                # of value the tar branch obtains from getnames().
                # The extraction itself happens once, further below.
                file_names = archive.namelist()
            elif self.import_source.endswith(".tar.gz") \
                     or self.import_source.endswith(".tgz") \
                     or self.import_source.endswith(".tar.bz2") \
                     or self.import_source.endswith(".tbz2") \
                     or self.import_source.endswith(".tar"):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            root = find_root_of_archive(file_names)
            if root is None:
                logger.critical("Cannot find a root directory in %s." %
                                self.import_source)
                return False

            self.import_dir = tempfile.mkdtemp()
            archive.extractall(self.import_dir)
            self.import_dir = os.path.join(self.import_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                metadata.drop_all()
            except sqlalchemy.exc.OperationalError as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False
        try:
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        logger.info("Reading JSON file...")
        with open(os.path.join(self.import_dir, "contest.json")) as fin:
            contest_json = json.load(fin)
        if self.no_submissions:
            for user in contest_json["users"]:
                user["submissions"] = []
                user["user_tests"] = []

        # Both stay None when only the files are imported, so that the
        # code below never refers to unbound names.
        contest_id = None
        contest_files = None

        if not self.only_files:
            with SessionGen(commit=False) as session:

                # Import the contest in JSON format.
                logger.info("Importing the contest from JSON file.")
                contest = Contest.import_from_dict(contest_json)
                session.add(contest)

                # Flush so that the contest gets its id assigned.
                session.flush()
                contest_id = contest.id
                contest_files = contest.enumerate_files()
                session.commit()

        if not self.no_files:
            logger.info("Importing files.")
            files_dir = os.path.join(self.import_dir, "files")
            descr_dir = os.path.join(self.import_dir, "descriptions")
            if contest_files is None:
                # No contest was imported, so there is no list of
                # needed files: import every file found in the dump.
                contest_files = set(os.listdir(files_dir))
            for digest in contest_files:
                file_ = os.path.join(files_dir, digest)
                desc = os.path.join(descr_dir, digest)
                if not os.path.exists(file_) or not os.path.exists(desc):
                    logger.error("Some files needed to the contest "
                                 "are missing in the import directory. "
                                 "The import will continue. Be aware.")
                if not self.safe_put_file(file_, desc):
                    logger.critical("Unable to put file `%s' in the database. "
                                    "Aborting. Please remove the contest "
                                    "from the database." % file_)
                    # TODO: remove contest from the database.
                    return False

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." % contest_id)
        else:
            logger.info("Import finished.")
        logger.operation = ""

        # If we extracted an archive, we remove it.
        # NOTE: the early "return False" paths above skip this cleanup.
        if self.import_dir != self.import_source:
            shutil.rmtree(self.import_dir)

        return True
    def do_import(self):
        """Run the actual import code.

        Extract the import source if it is an archive, (re)create the
        database structure, store the dumped files and finally load
        the contest described by contest.json.

        return (bool): True if the import completed, False if one of
                       the steps failed (the reason is logged).

        """
        logger.operation = "importing contest from %s" % self.import_source
        logger.info("Starting import.")

        # Directory created by mkdtemp to receive the archive content;
        # stays None when we import straight from a directory.
        extraction_dir = None

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                # namelist() yields plain names, like getnames() does
                # for tar archives; infolist() would yield ZipInfo
                # objects instead.
                file_names = archive.namelist()
            elif self.import_source.endswith(
                    (".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            try:
                root = find_root_of_archive(file_names)
                if root is None:
                    logger.critical("Cannot find a root directory in %s." %
                                    self.import_source)
                    return False

                # Extract exactly once, whatever the archive format.
                extraction_dir = tempfile.mkdtemp()
                archive.extractall(extraction_dir)
            finally:
                archive.close()

            self.import_dir = os.path.join(extraction_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                metadata.drop_all()
            except sqlalchemy.exc.OperationalError as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False
        try:
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        if not self.no_files:
            logger.info("Importing files.")
            files_dir = os.path.join(self.import_dir, "files")
            descr_dir = os.path.join(self.import_dir, "descriptions")
            files = set(os.listdir(files_dir))
            for _file in files:
                if not self.safe_put_file(os.path.join(files_dir, _file),
                                          os.path.join(descr_dir, _file)):
                    return False

        # Remains None when only the files are imported, so that the
        # final log message does not refer to an unbound name.
        contest_id = None

        if not self.only_files:
            with SessionGen(commit=False) as session:

                # Import the contest in JSON format.
                logger.info("Importing the contest from JSON file.")
                with open(os.path.join(self.import_dir,
                                       "contest.json")) as fin:
                    contest = Contest.import_from_dict(json.load(fin))
                    session.add(contest)

                # Check that no files were missing (only if files were
                # imported).
                if not self.no_files:
                    contest_files = contest.enumerate_files()
                    missing_files = contest_files.difference(files)
                    if missing_files:
                        logger.warning("Some files needed to the contest "
                                       "are missing in the import directory.")

                # Flush so that the contest gets its id assigned.
                session.flush()
                contest_id = contest.id
                session.commit()

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." % contest_id)
        else:
            logger.info("Import finished.")
        logger.operation = ""

        # If we extracted an archive, we remove the whole temporary
        # directory (not just the root directory inside it).
        if extraction_dir is not None:
            shutil.rmtree(extraction_dir)
        elif self.import_dir != self.import_source:
            shutil.rmtree(self.import_dir)

        return True
Example #12
0
    def do_import(self):
        """Run the actual import code.

        Extract the import source if it is an archive, (re)create the
        database structure, then load the objects described by
        contest.json (if self.load_model) and the dumped files (if
        self.load_files).

        return (bool): True if the import completed, False if one of
                       the steps failed (the reason is logged).

        """
        logger.operation = "importing contest from %s" % self.import_source
        logger.info("Starting import.")

        # Directory created by mkdtemp to receive the archive content;
        # stays None when we import straight from a directory.
        extraction_dir = None

        if not os.path.isdir(self.import_source):
            if self.import_source.endswith(".zip"):
                archive = zipfile.ZipFile(self.import_source, "r")
                # namelist() yields plain names, like getnames() does
                # for tar archives; infolist() would yield ZipInfo
                # objects instead.
                file_names = archive.namelist()
            elif self.import_source.endswith(
                    (".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")):
                archive = tarfile.open(name=self.import_source)
                file_names = archive.getnames()
            else:
                logger.critical("Unable to import from %s." %
                                self.import_source)
                return False

            try:
                root = find_root_of_archive(file_names)
                if root is None:
                    logger.critical("Cannot find a root directory in %s." %
                                    self.import_source)
                    return False

                # Extract exactly once, whatever the archive format.
                extraction_dir = tempfile.mkdtemp()
                archive.extractall(extraction_dir)
            finally:
                archive.close()

            self.import_dir = os.path.join(extraction_dir, root)

        if self.drop:
            logger.info("Dropping and recreating the database.")
            try:
                metadata.drop_all()
            except sqlalchemy.exc.OperationalError as error:
                logger.critical("Unable to access DB.\n%r" % error)
                return False
        try:
            metadata.create_all()
        except sqlalchemy.exc.OperationalError as error:
            logger.critical("Unable to access DB.\n%r" % error)
            return False

        with SessionGen(commit=False) as session:

            # Import the contest in JSON format.
            if self.load_model:
                logger.info("Importing the contest from a JSON file.")

                with io.open(os.path.join(self.import_dir,
                                          "contest.json"), "rb") as fin:
                    # TODO - Throughout all the code we'll assume the
                    # input is correct without actually doing any
                    # validations.  Thus, for example, we're not
                    # checking that the decoded object is a dict...
                    self.datas = json.load(fin, encoding="utf-8")

                # First pass: build every object. Keys starting with
                # an underscore are not objects and are skipped.
                self.objs = dict()
                for id_, data in self.datas.iteritems():
                    if not id_.startswith("_"):
                        self.objs[id_] = self.import_object(data)
                # Second pass: wire the relationships, now that all
                # the objects exist.
                for id_, data in self.datas.iteritems():
                    if not id_.startswith("_"):
                        self.add_relationships(data, self.objs[id_])

                # Iterate over a copy since we delete while looping.
                for k, v in list(self.objs.iteritems()):
                    # Skip submissions and/or user_tests if requested.
                    if (self.skip_submissions
                            and isinstance(v, Submission)) \
                            or (self.skip_user_tests
                                and isinstance(v, UserTest)):
                        del self.objs[k]

                contest_id = list()
                contest_files = set()

                # Add each base object and all its dependencies
                for id_ in self.datas["_objects"]:
                    contest = self.objs[id_]

                    # We explictly add only the contest since all child
                    # objects will be automatically added by cascade.
                    # Adding each object individually would also add
                    # orphaned objects like the ones that depended on
                    # submissions or user_tests that we (possibly)
                    # removed above.
                    session.add(contest)
                    # Flush so that the contest gets its id assigned.
                    session.flush()

                    contest_id.append(contest.id)
                    contest_files |= contest.enumerate_files(
                        self.skip_submissions, self.skip_user_tests,
                        self.light)

                session.commit()
            else:
                # None means "no model loaded": no id to report and no
                # restriction on the files to import.
                contest_id = None
                contest_files = None

            # Import files.
            if self.load_files:
                logger.info("Importing files.")

                files_dir = os.path.join(self.import_dir, "files")
                descr_dir = os.path.join(self.import_dir, "descriptions")

                files = set(os.listdir(files_dir))
                descr = set(os.listdir(descr_dir))

                if not descr <= files:
                    logger.warning("Some files do not have an associated "
                                   "description.")
                if not files <= descr:
                    logger.warning("Some descriptions do not have an "
                                   "associated file.")

                if contest_files is not None:
                    if not files <= contest_files:
                        # FIXME Check if it's because this is a light
                        # import or because we're skipping submissions
                        # or user_tests
                        logger.warning("The dump contains some files that "
                                       "are not needed by the contest.")
                    if not contest_files <= files:
                        # The reason for this could be that it was a
                        # light export that's not being reimported as
                        # such.
                        logger.warning("The contest needs some files that "
                                       "are not contained in the dump.")

                    # Limit import to files we actually need.
                    files &= contest_files

                for digest in files:
                    file_ = os.path.join(files_dir, digest)
                    desc = os.path.join(descr_dir, digest)
                    if not self.safe_put_file(file_, desc):
                        logger.critical("Unable to put file `%s' in the database. "
                                        "Aborting. Please remove the contest "
                                        "from the database." % file_)
                        # TODO: remove contest from the database.
                        return False

        if contest_id is not None:
            logger.info("Import finished (contest id: %s)." %
                        ", ".join(str(id_) for id_ in contest_id))
        else:
            logger.info("Import finished.")
        logger.operation = ""

        # If we extracted an archive, we remove the whole temporary
        # directory (not just the root directory inside it).
        if extraction_dir is not None:
            shutil.rmtree(extraction_dir)
        elif self.import_dir != self.import_source:
            shutil.rmtree(self.import_dir)

        return True