def test_validation_set(data_dir):
    with open(data_dir + '/gnews/GoogleNews.txt') as filino:
        data = filino.readlines()

    tp = TopicModelDataPreparation("distiluse-base-multilingual-cased")
    training_dataset = tp.create_training_set(data[:100], data[:100])
    validation_dataset = tp.create_validation_set(data[100:105], data[100:105])

    ctm = ZeroShotTM(input_size=len(tp.vocab), bert_input_size=512,
                     num_epochs=100, n_components=5)
    ctm.fit(training_dataset, validation_dataset=validation_dataset,
            patience=5, save_dir=data_dir + 'test_checkpoint')

    assert os.path.exists(data_dir + "test_checkpoint")
def test_training_all_classes_ctm(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab), bert_input_size=512, num_epochs=1,
              inference_type="combined", n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = ZeroShotTM(input_size=len(handler.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = CombinedTM(input_size=len(handler.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    with open(data_dir + 'sample_text_document') as filino:
        data = filino.readlines()

    handler = TextHandler(sentences=data)
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_list(
        data, "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab), bert_input_size=512, num_epochs=1,
              inference_type="combined", n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    qt = QuickText("distiluse-base-multilingual-cased",
                   text_for_bow=data, text_for_bert=data)
    dataset = qt.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    qt_from_conf = QuickText("distiluse-base-multilingual-cased", None, None)
    qt_from_conf.load_configuration(qt.bow, qt.data_bert, qt.vocab, qt.idx2token)
    dataset = qt_from_conf.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
sp = WhiteSpacePreprocessing(documents, stopwords_language=stopwords_lang)
preprocessed_documents, unpreprocessed_corpus, vocab = sp.preprocess()

# get BERT embeddings of docs
print("Get BERT encoding")
qt = QuickText("distiluse-base-multilingual-cased",
               text_for_bow=preprocessed_documents,
               text_for_bert=unpreprocessed_corpus)
training_dataset = qt.load_dataset()

# train CTM
print("Start training CTM")
ctm = ZeroShotTM(input_size=len(qt.vocab), bert_input_size=512,
                 n_components=args.num_topics, num_epochs=args.epochs)
ctm.fit(training_dataset)

# check topics found
print("Done training CTM!")
topics = ctm.get_topic_lists(20)
for i, topic in enumerate(topics):
    print("Topic", i, ":", ', '.join(topic))

# save trained model
model_file = "results/cldr/ctm_" + args.train_data.split(
    "/")[-1][:-4] + "_" + str(args.num_topics) + "topics_" + str(
    args.epochs) + "epochs.pkl"
with open(model_file, 'wb') as f:
data_preparation = joblib.load(args.data_preparation)
prepared_training_dataset = joblib.load(args.prepared_training_dataset)
logging.info("CTM training resources loaded")

df = create_model_dataframe()
logging.info(f'Vocab length: {len(data_preparation.vocab)}')

# 768-dim contextual embeddings for the English model,
# 512-dim for the multilingual distiluse model
input_size = 768 if args.lang == "en" else 512

for k in topics:
    start = time.time()
    if args.inference == "combined":
        ctm = CombinedTM(input_size=len(data_preparation.vocab),
                         bert_input_size=input_size, n_components=k)
    else:
        ctm = ZeroShotTM(input_size=len(data_preparation.vocab),
                         bert_input_size=input_size, n_components=k)
    ctm.fit(prepared_training_dataset)
    end = time.time()

    topic_words = ctm.get_topic_lists(20)

    unnormalized_topic_word_dist = ctm.get_topic_word_matrix()
    softmax = torch.nn.Softmax(dim=1)  # normalizes each topic's word scores into a probability distribution
    topic_word_dist = softmax(torch.from_numpy(unnormalized_topic_word_dist))

    topics_with_word_probs = get_topics_with_word_probabilities(
        ctm.train_data.idx2token, topic_word_dist)
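
# The helper get_topics_with_word_probabilities is called above but is not
# defined in this fragment. Below is a minimal sketch of what such a helper
# could look like (an assumption, not the original implementation): pair every
# vocabulary word with its probability under each topic, sorted descending.
def get_topics_with_word_probabilities(idx2token, topic_word_dist, top_n=20):
    topics_with_probs = []
    for topic_row in topic_word_dist:
        # topic_row is one row of the softmax-normalized topic-word matrix
        word_probs = [(idx2token[i], float(p)) for i, p in enumerate(topic_row)]
        word_probs.sort(key=lambda pair: pair[1], reverse=True)
        topics_with_probs.append(word_probs[:top_n])
    return topics_with_probs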