Пример #1
0
def GetVocabFromMetaTable(session: sqlutil.Session) -> typing.Dict[str, int]:
  """Read a vocabulary dictionary from the 'Meta' table of a database.

  The vocabulary is looked up as one 'vocab_size' row holding the number of
  entries, followed by one 'vocab_<index>' row per entry whose value is the
  token for that index.

  Args:
    session: The database session to query.

  Returns:
    A dictionary mapping each stored token to its integer index, or an empty
    dictionary if there is no 'vocab_size' row.

  Raises:
    Whatever the query's one() / one_or_none() raise when a 'vocab_<index>'
    row is missing or a key is duplicated (for SQLAlchemy: NoResultFound and
    MultipleResultsFound).
  """
  values = session.query(encoded.Meta.value)

  # Fetch the size row once: one_or_none() yields None when the row is absent
  # and still rejects duplicate rows, so no separate existence query is needed.
  size_row = values.filter(encoded.Meta.key == 'vocab_size').one_or_none()
  if size_row is None:
    return {}

  vocab_size = int(size_row[0])
  # One lookup per index; one() makes a missing entry an error rather than
  # silently producing a short vocabulary.
  return {
      values.filter(encoded.Meta.key == f'vocab_{i}').one()[0]: i
      for i in range(vocab_size)
  }
Пример #2
0
 def Import(self, session: sqlutil.Session,
            config: corpus_pb2.Corpus) -> None:
     """Preprocess content files that are not yet in the database.

     Content files whose relpath is already recorded as an `input_relpath`
     are skipped. The remaining files are preprocessed by a pool of worker
     processes and each result is added to `session`.

     Each result's `wall_time_ms` is the time elapsed since the previous
     result arrived from the pool, not the time a worker spent on that file.

     The session is committed at most once every 10 seconds while results
     arrive. No commit is issued after the final result, so rows added since
     the last periodic commit are left for the caller to commit.

     Args:
         session: The database session to add preprocessed files to.
         config: The corpus config, providing the content file root and the
             preprocessors to run.
     """
     with self.GetContentFileRoot(config) as contentfile_root:
         relpaths = set(self.GetImportRelpaths(contentfile_root))
         done = {
             row[0]
             for row in session.query(PreprocessedContentFile.input_relpath)
         }
         todo = relpaths - done
         logging.info('Preprocessing %s of %s content files',
                      humanize.intcomma(len(todo)),
                      humanize.intcomma(len(relpaths)))
         jobs = [
             internal_pb2.PreprocessorWorker(
                 contentfile_root=str(contentfile_root),
                 relpath=t,
                 preprocessors=config.preprocessor) for t in todo
         ]
         # Use the pool as a context manager so the worker processes are
         # always torn down, including when a worker or the session raises.
         with multiprocessing.Pool() as pool:
             bar = progressbar.ProgressBar(max_value=len(jobs))
             last_commit = time.time()
             wall_time_start = time.time()
             for preprocessed_cf in bar(
                     pool.imap_unordered(PreprocessorWorker, jobs)):
                 wall_time_end = time.time()
                 preprocessed_cf.wall_time_ms = (int(
                     (wall_time_end - wall_time_start) * 1000))
                 wall_time_start = wall_time_end
                 session.add(preprocessed_cf)
                 # Commit periodically so progress survives an interruption.
                 if wall_time_end - last_commit > 10:
                     session.commit()
                     last_commit = wall_time_end
Пример #3
0
def StoreVocabInMetaTable(session: sqlutil.Session,
                          vocabulary: typing.Dict[str, int]):
  """Write a vocabulary dictionary into the 'Meta' table of a database.

  Existing rows whose key matches the LIKE pattern 'vocab_%' are deleted
  first. Then a 'vocab_size' row and one 'vocab_<index>' row per token are
  added to the session. The session is not committed here.

  Args:
    session: The database session to modify.
    vocabulary: A dictionary mapping tokens to their integer indices.
  """
  # Clear out any previously stored vocabulary rows.
  stale_rows = session.query(encoded.Meta).filter(
      encoded.Meta.key.like('vocab_%'))
  stale_rows.delete(synchronize_session=False)

  size_row = encoded.Meta(key='vocab_size', value=str(len(vocabulary)))
  session.add(size_row)

  # Rows are keyed by index and hold the token as their value.
  token_rows = []
  for token, index in vocabulary.items():
    token_rows.append(encoded.Meta(key=f'vocab_{index}', value=token))
  session.add_all(token_rows)
Пример #4
0
    def Import(self, session: sqlutil.Session,
               preprocessed_db: preprocessed.PreprocessedContentFiles,
               atomizer: atomizers.AtomizerBase,
               contentfile_separator: str) -> None:
        """Encode preprocessed content files and add the results to `session`.

        Selects the files in `preprocessed_db` that were preprocessed
        successfully and whose id is not already present as an
        `EncodedContentFile` id, then encodes them with a pool of worker
        processes. Workers may return a falsy result, which is skipped.

        Each result's `wall_time_ms` is the time elapsed since the previous
        result arrived from the pool, not the time a worker spent on that
        file.

        The session is committed at most once every 10 seconds while results
        arrive. No commit is issued after the final result, so rows added
        since the last periodic commit are left for the caller to commit.

        Args:
            session: The database session to add encoded files to.
            preprocessed_db: The database of preprocessed content files.
            atomizer: The atomizer handed (pickled) to every encoder worker.
            contentfile_separator: The separator passed to every encoder
                worker.

        Raises:
            errors.EmptyCorpusException: If there are no files to encode.
                NOTE(review): this also fires when every successfully
                preprocessed file is already encoded, although the message
                says the corpus contains no files.
        """
        with preprocessed_db.Session() as p_session:
            # NOTE(review): .all() returns row objects rather than bare ids;
            # confirm in_() accepts them when the encoded table is non-empty.
            query = p_session.query(
                preprocessed.PreprocessedContentFile).filter(
                    preprocessed.PreprocessedContentFile.
                    preprocessing_succeeded == True,
                    ~preprocessed.PreprocessedContentFile.id.in_(
                        session.query(EncodedContentFile.id).all()))
            # The atomizer is identical for every job, so serialize it once
            # instead of once per content file.
            pickled_atomizer = pickle.dumps(atomizer)
            jobs = [
                internal_pb2.EncoderWorker(
                    id=x.id,
                    text=x.text,
                    contentfile_separator=contentfile_separator,
                    pickled_atomizer=pickled_atomizer) for x in query
            ]
            if not jobs:
                raise errors.EmptyCorpusException(
                    "Pre-processed corpus contains no files: "
                    f"'{preprocessed_db.url}'")

            logging.info(
                'Encoding %s of %s preprocessed files',
                humanize.intcomma(query.count()),
                humanize.intcomma(
                    p_session.query(
                        preprocessed.PreprocessedContentFile).filter(
                            preprocessed.PreprocessedContentFile.
                            preprocessing_succeeded == True).count()))
            # Use the pool as a context manager so the worker processes are
            # always torn down, including when a worker or the session raises.
            with multiprocessing.Pool() as pool:
                bar = progressbar.ProgressBar(max_value=len(jobs))
                last_commit = time.time()
                wall_time_start = time.time()
                for encoded_cf in bar(
                        pool.imap_unordered(EncoderWorker, jobs)):
                    wall_time_end = time.time()
                    # TODO(cec): Remove the if check once EncoderWorker no
                    # longer returns None on atomizer encode error.
                    if encoded_cf:
                        encoded_cf.wall_time_ms = int(
                            (wall_time_end - wall_time_start) * 1000)
                        session.add(encoded_cf)
                    wall_time_start = wall_time_end
                    # Commit periodically so progress survives an
                    # interruption.
                    if wall_time_end - last_commit > 10:
                        session.commit()
                        last_commit = wall_time_end
Пример #5
0
 def IsDone(self, session: sqlutil.Session) -> bool:
     """Return whether the database holds a 'done' marker.

     Args:
         session: The database session to query.

     Returns:
         True if a Meta row with the key 'done' exists, otherwise False.
     """
     return bool(session.query(Meta).filter(Meta.key == 'done').first())