Example #1
def Import(self, session: sqlutil.Session,
           config: corpus_pb2.Corpus) -> None:
    """Preprocess the content files of a corpus and import the results."""
    with self.GetContentFileRoot(config) as contentfile_root:
        relpaths = set(self.GetImportRelpaths(contentfile_root))
        # Skip files that an earlier (possibly interrupted) import has
        # already preprocessed, making the import resumable.
        done = set(
            x[0] for x in session.query(PreprocessedContentFile.input_relpath))
        todo = relpaths - done
        logging.info('Preprocessing %s of %s content files',
                     humanize.intcomma(len(todo)),
                     humanize.intcomma(len(relpaths)))
        jobs = [
            internal_pb2.PreprocessorWorker(
                contentfile_root=str(contentfile_root),
                relpath=t,
                preprocessors=config.preprocessor) for t in todo
        ]
        pool = multiprocessing.Pool()
        bar = progressbar.ProgressBar(max_value=len(jobs))
        last_commit = time.time()
        wall_time_start = time.time()
        # Results arrive in completion order. Stamp each with the wall time
        # since the previous result and commit at most every ten seconds.
        for preprocessed_cf in bar(
                pool.imap_unordered(PreprocessorWorker, jobs)):
            wall_time_end = time.time()
            preprocessed_cf.wall_time_ms = int(
                (wall_time_end - wall_time_start) * 1000)
            wall_time_start = wall_time_end
            session.add(preprocessed_cf)
            if wall_time_end - last_commit > 10:
                session.commit()
                last_commit = wall_time_end
        # Rows added since the last periodic commit are left for the caller
        # to commit.
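Both examples on this page use the same throughput pattern: fan jobs out to a multiprocessing.Pool with imap_unordered, stamp each result with the wall time since the previous one, and batch database writes by committing at most once every ten seconds. Below is a minimal, self-contained sketch of that pattern; the Result table and Worker function are invented for this demo and are not part of the code above:

import multiprocessing
import time

import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Result(Base):
    """Toy table standing in for PreprocessedContentFile."""
    __tablename__ = 'results'
    id = sa.Column(sa.Integer, primary_key=True)
    value = sa.Column(sa.Integer)
    wall_time_ms = sa.Column(sa.Integer)


def Worker(job: int) -> Result:
    # Must be a module-level function so multiprocessing can pickle it.
    return Result(value=job * job)


if __name__ == '__main__':
    engine = sa.create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    jobs = list(range(1000))
    pool = multiprocessing.Pool()
    last_commit = time.time()
    wall_time_start = time.time()
    # imap_unordered yields results as they complete, so one slow job does
    # not stall the writer; result order is not preserved.
    for result in pool.imap_unordered(Worker, jobs):
        wall_time_end = time.time()
        result.wall_time_ms = int((wall_time_end - wall_time_start) * 1000)
        wall_time_start = wall_time_end
        session.add(result)
        # Batch writes: commit at most once every ten seconds.
        if wall_time_end - last_commit > 10:
            session.commit()
            last_commit = wall_time_end
    session.commit()  # Pick up rows added since the last periodic commit.

Note that wall_time_ms here, as in the examples above, measures the interval between consecutive results on the consumer side, not the time a worker actually spent on the job.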
Example #2
def Import(self, session: sqlutil.Session,
           preprocessed_db: preprocessed.PreprocessedContentFiles,
           atomizer: atomizers.AtomizerBase,
           contentfile_separator: str) -> None:
    """Encode successfully-preprocessed files and import the results."""
    with preprocessed_db.Session() as p_session:
        # The preprocessed and encoded tables live in separate databases, so
        # the ids of already-encoded files are materialized in Python rather
        # than expressed as a SQL subquery.
        done_ids = [x[0] for x in session.query(EncodedContentFile.id)]
        query = p_session.query(preprocessed.PreprocessedContentFile).filter(
            preprocessed.PreprocessedContentFile.preprocessing_succeeded
            == True,  # SQLAlchemy comparison; `is True` would not build SQL.
            ~preprocessed.PreprocessedContentFile.id.in_(done_ids))
        jobs = [
            internal_pb2.EncoderWorker(
                id=x.id,
                text=x.text,
                contentfile_separator=contentfile_separator,
                # Each job carries a pickled copy of the atomizer so that
                # worker processes can reconstruct it.
                pickled_atomizer=pickle.dumps(atomizer)) for x in query
        ]
        if not jobs:
            raise errors.EmptyCorpusException(
                "Pre-processed corpus contains no files: "
                f"'{preprocessed_db.url}'")

        logging.info(
            'Encoding %s of %s preprocessed files',
            humanize.intcomma(query.count()),
            humanize.intcomma(
                p_session.query(preprocessed.PreprocessedContentFile).filter(
                    preprocessed.PreprocessedContentFile
                    .preprocessing_succeeded == True).count()))
        pool = multiprocessing.Pool()
        bar = progressbar.ProgressBar(max_value=len(jobs))
        last_commit = time.time()
        wall_time_start = time.time()
        for encoded_cf in bar(pool.imap_unordered(EncoderWorker, jobs)):
            wall_time_end = time.time()
            # TODO(cec): Remove the if check once EncoderWorker no longer
            # returns None on atomizer encode error.
            if encoded_cf:
                encoded_cf.wall_time_ms = int(
                    (wall_time_end - wall_time_start) * 1000)
                session.add(encoded_cf)
            wall_time_start = wall_time_end
            if wall_time_end - last_commit > 10:
                session.commit()
                last_commit = wall_time_end
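The filter at the top of Example #2 is an incremental anti-join across two databases: ids already present in the encoded database are pulled into memory, then excluded with ~Column.in_(...). Here is a minimal, self-contained sketch of the idiom; the Source and Done tables are invented stand-ins for PreprocessedContentFile and EncodedContentFile:

import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Source(Base):
    """Stands in for preprocessed.PreprocessedContentFile."""
    __tablename__ = 'source'
    id = sa.Column(sa.Integer, primary_key=True)
    succeeded = sa.Column(sa.Boolean)


class Done(Base):
    """Stands in for EncodedContentFile."""
    __tablename__ = 'done'
    id = sa.Column(sa.Integer, primary_key=True)


def MakeSession(url):
    engine = sa.create_engine(url)
    Base.metadata.create_all(engine)
    return sessionmaker(bind=engine)()


src = MakeSession('sqlite://')  # The "preprocessed" database.
dst = MakeSession('sqlite://')  # The "encoded" database.
src.add_all([Source(id=i, succeeded=i % 2 == 0) for i in range(6)])
dst.add_all([Done(id=0), Done(id=2)])
src.commit()
dst.commit()

# Because src and dst are different databases, the done ids must be
# materialized in Python; a SQL subquery cannot span the two engines.
done_ids = [row[0] for row in dst.query(Done.id)]
todo = src.query(Source).filter(
    Source.succeeded == True,
    ~Source.id.in_(done_ids))
print([s.id for s in todo])  # -> [4]

If both tables lived in the same database, the in-memory id list could instead be a SQL subquery, keeping the anti-join entirely on the database side.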