def GetVocabFromMetaTable(session: sqlutil.Session) -> typing.Dict[str, int]:
  """Load a vocabulary dictionary from the 'Meta' table of a database.

  The vocabulary is read as a 'vocab_size' row plus one 'vocab_<i>' row for
  each index i in [0, vocab_size), whose value is the token for that index.

  Args:
    session: A session on the database holding the Meta table.

  Returns:
    A dictionary mapping each token string to its integer index, or an empty
    dictionary when there is no 'vocab_size' row.

  Raises:
    Whatever `Query.one()` raises if the 'vocab_size' row or any
    'vocab_<i>' row is missing or duplicated.
  """
  value_query = session.query(encoded.Meta.value)
  size_query = value_query.filter(encoded.Meta.key == 'vocab_size')
  if not size_query.first():
    return {}
  vocab_size = int(size_query.one()[0])
  vocabulary = {}
  # One lookup per index; if a token repeats, the later index wins.
  for index in range(vocab_size):
    token = value_query.filter(encoded.Meta.key == f'vocab_{index}').one()[0]
    vocabulary[token] = index
  return vocabulary
def Import(self, session: sqlutil.Session, config: corpus_pb2.Corpus) -> None:
  """Preprocess the content files which are not yet recorded in this database.

  Relpaths already stored as `PreprocessedContentFile.input_relpath` are
  skipped, so a partially completed import can be resumed. The remaining files
  are preprocessed in a process pool, and each result is added to `session` as
  it arrives.

  Args:
    session: Session on this database which receives the results.
    config: Corpus config. It supplies the content file root (via
      GetContentFileRoot) and the preprocessors to run.
  """
  with self.GetContentFileRoot(config) as contentfile_root:
    relpaths = set(self.GetImportRelpaths(contentfile_root))
    done = {
        row[0]
        for row in session.query(PreprocessedContentFile.input_relpath)
    }
    todo = relpaths - done
    logging.info('Preprocessing %s of %s content files',
                 humanize.intcomma(len(todo)),
                 humanize.intcomma(len(relpaths)))
    jobs = [
        internal_pb2.PreprocessorWorker(
            contentfile_root=str(contentfile_root),
            relpath=t,
            preprocessors=config.preprocessor) for t in todo
    ]
    # Use the pool as a context manager so that its worker processes are
    # terminated when the import finishes or raises, rather than leaked.
    with multiprocessing.Pool() as pool:
      bar = progressbar.ProgressBar(max_value=len(jobs))
      last_commit = time.time()
      wall_time_start = time.time()
      for preprocessed_cf in bar(
          pool.imap_unordered(PreprocessorWorker, jobs)):
        wall_time_end = time.time()
        # Milliseconds since the previous result arrived (or since the loop
        # started). This is not the per-job processing time.
        preprocessed_cf.wall_time_ms = int(
            (wall_time_end - wall_time_start) * 1000)
        wall_time_start = wall_time_end
        session.add(preprocessed_cf)
        # Commit at most about every 10 seconds rather than after every file.
        # Committed relpaths are skipped on a re-run (see `done` above).
        if wall_time_end - last_commit > 10:
          session.commit()
          last_commit = wall_time_end
    # NOTE(review): results added after the last periodic commit are not
    # committed here. Presumably the caller's session commits on exit.
    # Confirm.
def StoreVocabInMetaTable(session: sqlutil.Session,
                          vocabulary: typing.Dict[str, int]):
  """Write a vocabulary dictionary to the 'Meta' table, replacing any old one.

  Existing rows whose key matches LIKE 'vocab_%' are bulk-deleted. Then a
  'vocab_size' row and one 'vocab_<index>' row per token (value = the token)
  are added to the session. This function does not commit.

  Args:
    session: A session on the database holding the Meta table.
    vocabulary: Mapping from token string to integer index.
  """
  # NOTE(review): in SQL LIKE, '_' matches any single character, so this also
  # matches keys such as 'vocabX...'. That is harmless only while no other
  # Meta keys begin with 'vocab'.
  stale_rows = session.query(encoded.Meta).filter(
      encoded.Meta.key.like('vocab_%'))
  stale_rows.delete(synchronize_session=False)
  size_row = encoded.Meta(key='vocab_size', value=str(len(vocabulary)))
  token_rows = [
      encoded.Meta(key=f'vocab_{index}', value=token)
      for token, index in vocabulary.items()
  ]
  session.add(size_row)
  session.add_all(token_rows)
def Import(self, session: sqlutil.Session,
           preprocessed_db: preprocessed.PreprocessedContentFiles,
           atomizer: atomizers.AtomizerBase,
           contentfile_separator: str) -> None:
  """Encode the successfully preprocessed files not yet in this database.

  Files are read from `preprocessed_db` and encoded in a process pool.
  Results are added to `session` as they arrive.

  Args:
    session: Session on this (encoded) database which receives the results.
    preprocessed_db: Database of preprocessed content files to read from.
    atomizer: Atomizer sent, pickled, to every encoder worker.
    contentfile_separator: Separator string sent to every encoder worker.

  Raises:
    errors.EmptyCorpusException: If there are no files left to encode.
  """
  PCF = preprocessed.PreprocessedContentFile
  # IDs already encoded. They come from `session`, not `p_session`, so they
  # are materialized as plain scalars. A single-column Query.all() yields
  # one-element row tuples, which cannot be bound as IN-list values.
  encoded_ids = [row[0] for row in session.query(EncodedContentFile.id)]
  with preprocessed_db.Session() as p_session:
    # `== True` (not `is True`) is required to build a SQL expression.
    succeeded = p_session.query(PCF).filter(
        PCF.preprocessing_succeeded == True)  # noqa: E712
    todo = succeeded.filter(~PCF.id.in_(encoded_ids)).all()
    # NOTE(review): this also fires when every successfully preprocessed file
    # is already encoded, not only when the corpus is empty. Confirm this is
    # intended.
    if not todo:
      raise errors.EmptyCorpusException(
          "Pre-processed corpus contains no files: "
          f"'{preprocessed_db.url}'")
    # Every job receives the same bytes, so pickle the atomizer once.
    pickled_atomizer = pickle.dumps(atomizer)
    jobs = [
        internal_pb2.EncoderWorker(
            id=x.id,
            text=x.text,
            contentfile_separator=contentfile_separator,
            pickled_atomizer=pickled_atomizer) for x in todo
    ]
    logging.info('Encoding %s of %s preprocessed files',
                 humanize.intcomma(len(jobs)),
                 humanize.intcomma(succeeded.count()))
    # Use the pool as a context manager so that its worker processes are
    # terminated when encoding finishes or raises, rather than leaked.
    with multiprocessing.Pool() as pool:
      bar = progressbar.ProgressBar(max_value=len(jobs))
      last_commit = time.time()
      wall_time_start = time.time()
      for encoded_cf in bar(pool.imap_unordered(EncoderWorker, jobs)):
        wall_time_end = time.time()
        # TODO(cec): Remove the if check once EncoderWorker no longer returns
        # None on atomizer encode error.
        if encoded_cf:
          # Milliseconds since the previous result arrived (or since the loop
          # started). This is not the per-job processing time.
          encoded_cf.wall_time_ms = int(
              (wall_time_end - wall_time_start) * 1000)
          session.add(encoded_cf)
        wall_time_start = wall_time_end
        # Commit at most about every 10 seconds rather than after every file.
        if wall_time_end - last_commit > 10:
          session.commit()
          last_commit = wall_time_end
def IsDone(self, session: sqlutil.Session):
  """Return whether the Meta table contains a row keyed 'done'.

  Args:
    session: Session on the database to check.

  Returns:
    True if a Meta row with key 'done' exists, otherwise False.
  """
  done_row = session.query(Meta).filter(Meta.key == 'done').first()
  return bool(done_row)