"""Train a topic-regularized contextualized topic model (TCCTM) on English
Wikipedia, using a fine-tuned XLM-RoBERTa sentence-transformers checkpoint
for the contextual document embeddings."""
import os
import pickle

import numpy as np
import torch

from contextualized_topic_models.datasets.dataset import CTMDatasetTopReg
# NOTE(review): CTM and the two helpers below were referenced but never
# imported in the original script (NameError at runtime). Import paths
# assume the standard contextualized-topic-models package layout -- confirm.
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import (
    TextHandler,
    bert_embeddings_from_file,
    labels_from_file,
)

# Bag-of-words side of the model: built from the *preprocessed* corpus.
handler = TextHandler(
    "contextualized_topic_models/data/wiki/wiki_train_en_prep.txt")
handler.prepare()

# Contextual side: sentence embeddings of the *unpreprocessed* text, encoded
# with a fine-tuned XLM-RoBERTa checkpoint.
train_bert = bert_embeddings_from_file(
    'contextualized_topic_models/data/wiki/wiki_train_en_unprep.txt',
    '../sentence-transformers/sentence_transformers/output/training_wiki_topics_4_xlm-roberta-base-2020-10-24_13-38-14/')

# Document-level topic labels driving the topic-regularization loss.
labels, nb_labels = labels_from_file(
    'contextualized_topic_models/data/wiki/topic_full_misc.json')

training_dataset = CTMDatasetTopReg(
    handler.bow, train_bert, handler.idx2token, labels)

num_topics = 100

# The nb_labels argument turns this from a CTM into a TCCTM.
ctm = CTM(input_size=len(handler.vocab), bert_input_size=768, num_epochs=60,
          hidden_sizes=(100, ), inference_type="contextual",
          n_components=num_topics, num_data_loader_workers=0,
          nb_labels=nb_labels)

ctm.fit(training_dataset)
ctm.save("models/wiki/wiki_xlmr_en_topics_4_topreg")
# Second training run: a plain CTM (no topic regularization) on the same
# corpus, with contextual embeddings from an NLI-fine-tuned XLM-R checkpoint.
#
# NOTE(review): CTMDataset, CTM and bert_embeddings_from_file were referenced
# but never imported here (NameError at runtime). Imports are added locally
# so this section is self-contained; re-importing is a no-op if the names are
# already in scope. Paths assume the standard package layout -- confirm.
from contextualized_topic_models.datasets.dataset import CTMDataset
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import (
    bert_embeddings_from_file,
)

# `handler` is the TextHandler built earlier in this file; re-running
# prepare() rebuilds the bag-of-words representation.
handler.prepare()

# Alternative sentence-embedding checkpoints tried previously:
#   'distiluse-base-multilingual-cased'
#   'xlm-r-100langs-bert-base-nli-mean-tokens'
train_bert = bert_embeddings_from_file(
    'contextualized_topic_models/data/wiki/wiki_train_en_unprep.txt',
    '../sentence-transformers/sentence_transformers/output/training_nli_wiki-xlmr-2020-12-15_00-20-18')

training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

num_topics = 100

ctm = CTM(input_size=len(handler.vocab), bert_input_size=768, num_epochs=60,
          hidden_sizes=(100, ), inference_type="contextual",
          n_components=num_topics, num_data_loader_workers=0)

ctm.fit(training_dataset)
ctm.save("models/wiki/wiki_xlmr_en_nli_ct")

# Optional post-hoc steps kept for reference (disabled in the original):
# torch.save(ctm, "wiki_en_xlmr_topicsiqos_1.ctm", pickle_protocol=4)
# with open("contextualized_topic_models/data/wiki/wiki_train_en_prep_sub.txt", "r") as en:
#     texts = [doc.split() for doc in en.read().splitlines()]
# npmi = CoherenceNPMI(texts=texts, topics=ctm.get_topic_lists(10))
# print(npmi.score())