def main():
    """Train LDA models over a grid of (parametrization, n_topics) settings.

    Workflow: parse CLI args, set up logging, load texts/dictionary/corpus,
    then for every (param, nbtopics) pair train a gensim ``LdaModel`` and
    persist its topics (CSV), metrics (JSON) and the model itself.
    Side effects: writes files under ``LDA_PATH`` and rebinds the module-level
    ``LOGG`` logging shortcut.
    """
    global LOGG

    # --- arguments ---
    (dataset, version, cb_logger, params, nbs_topics, epochs, cores,
     cache_in_memory, use_callbacks, corpus_type, args) = parse_args()

    model_class = 'LDAmodel'
    # Suffix marks runs trained on a train/test split (callbacks need a test set).
    _split = "_split" if use_callbacks else ""

    # --- logging ---
    logger = init_logging(
        name=f'LDA_{dataset}_{version}_{corpus_type}{_split}_ep{epochs}',
        basic=False, to_stdout=True, to_file=True)
    LOGG = logger.info
    log_args(logger, args)

    # --- load texts ---
    # Raw texts are only needed by the coherence/convergence callbacks.
    if use_callbacks:
        texts = load(dataset, version, 'texts', logger=logger)
    else:
        texts = []

    # --- load dict ---
    dictionary = load(dataset, version, corpus_type, 'dict', logger=logger)

    # --- load corpus ---
    corpus = load(dataset, version, corpus_type, 'corpus', logger=logger)
    if cache_in_memory:
        LOGG('Reading corpus into RAM')
        # Materialize the (presumably streamed) corpus to avoid repeated disk reads.
        corpus = list(corpus)
    if use_callbacks:
        train, test = split_corpus(corpus)
    else:
        train, test = corpus, []
    LOGG(f'size of... train_set={len(train)}, test_set={len(test)}')

    # --- enable visdom ---
    vis = None
    if cb_logger == 'visdom':
        try:
            import visdom
            vis = visdom.Visdom()
        except Exception as e:
            # Best effort: fall back to shell logging if visdom is unavailable.
            LOGG(e)
            cb_logger = 'shell'

    # --- train ---
    topn = 20
    columns = ([f'term{x}' for x in range(topn)]
               + [f'weight{x}' for x in range(topn)])
    metrics = []
    for param in params:
        env_id = f"{dataset}-{model_class}"
        for nbtopics in nbs_topics:
            gc.collect()

            callbacks = init_callbacks(
                dataset=dataset,
                callback_logger=cb_logger,
                documents=texts,
                training_corpus=train,
                test_corpus=test,
                processes=cores,
                version=version,
                param=param,
                nbtopics=nbtopics,
                tfidf=corpus_type,
            )
            if not use_callbacks:
                # Keep only the last callback when not training on a split
                # (the others need a held-out test set).
                callbacks = callbacks[-1:]
            else:
                LOGG('Initializing Callbacks')

            kwargs = get_parameterset(
                train,
                dictionary,
                callbacks=callbacks,
                nbtopics=nbtopics,
                parametrization=param,
                epochs=epochs,
            )

            LOGG(
                f'Running {model_class} {corpus_type} "{param}{_split}" '
                f'with {nbtopics} topics'
            )
            model = LdaModel(**kwargs)
            gc.collect()

            model_dir = join(LDA_PATH, version, corpus_type, f'{param}{_split}')
            model_path = join(
                model_dir,
                f'{dataset}_LDAmodel_{param}{_split}_{nbtopics}_ep{epochs}')
            # exist_ok avoids the check-then-create race of `if not exists(...)`.
            makedirs(model_dir, exist_ok=True)

            # --- save topics ---
            topics = model.show_topics(
                num_topics=-1, num_words=topn, formatted=False)
            # Flatten each topic's (term, weight) pairs into
            # [term0..termN, weight0..weightN] to match `columns`.
            topics = [list(chain(*zip(*topic[1]))) for topic in topics]
            topics = pd.DataFrame(topics, columns=columns)
            LOGG(f'Saving topics to {model_path}.csv')
            topics.to_csv(f'{model_path}.csv')

            # --- save metrics ---
            current_metrics = model.metrics
            # BUGFIX: append the env_id variable, not the literal string
            # 'env_id', so each metrics entry keeps its environment tag.
            metrics.append((env_id, current_metrics))
            with open(f'{model_path}_metrics.json', 'w') as fp:
                serializable_metrics = {}
                for k, v in current_metrics.items():
                    if k == dataset:
                        continue
                    # ndarray values -> nested lists; everything else -> floats,
                    # so the dict is JSON-serializable.
                    if isinstance(v[0], np.ndarray):
                        serializable_metrics[k] = [x.tolist() for x in v]
                    else:
                        serializable_metrics[k] = [float(x) for x in v]
                LOGG(f'Saving metrics to {model_path}_metrics.json')
                json.dump(serializable_metrics, fp)

            # --- save model ---
            LOGG(f'Saving LDAmodel to {model_path}')
            # Callbacks hold unpicklable references (loggers, corpora);
            # drop them before gensim serializes the model.
            model.callbacks = None
            model.save(model_path)

            # --- save visdom environment ---
            if vis is not None:
                vis.save([env_id])

            gc.collect()

    # --- done ---
    LOGG(f'\n'
         f'----- end -----\n'
         f'----- {dataset.upper()} -----\n'
         f'{"#" * 50}\n')