import pickle
from math import isclose

import numpy as np

# Import paths assume the contextualized_topic_models package layout these
# tests were written against; adjust if your package version differs.
from contextualized_topic_models.models.ctm import CTM, ZeroShotTM, CombinedTM
from contextualized_topic_models.utils.data_preparation import (
    TextHandler,
    QuickText,
    bert_embeddings_from_file,
    bert_embeddings_from_list,
)
from contextualized_topic_models.datasets.dataset import CTMDataset


def test_training_with_saved_data(data_dir):
    handler = TextHandler(data_dir + "gnews/GoogleNews.txt")
    handler.prepare()  # create vocabulary and training data

    # load pre-computed BERT embeddings for the GoogleNews sample
    with open(data_dir + "gnews/bert_embeddings_gnews", "rb") as filino:
        training_bert = pickle.load(filino)

    training_dataset = CTMDataset(handler.bow, training_bert, handler.idx2token)

    # the saved embeddings are 768-dimensional, hence bert_input_size=768
    ctm = CTM(input_size=len(handler.vocab), bert_input_size=768, num_epochs=1,
              inference_type="combined", n_components=5)
    ctm.fit(training_dataset)  # run the model

    print(ctm.get_topics(2))
    ctm.get_doc_topic_distribution(training_dataset)

    # each topic's word distribution must sum to 1
    assert isclose(np.sum(ctm.get_topic_word_distribution()[0]), 1,
                   rel_tol=1e-5, abs_tol=0.0)
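# A hedged sketch (not part of the original tests): how a pre-computed
# embedding pickle like gnews/bert_embeddings_gnews, loaded above, could be
# regenerated with the package's own helper. The model name
# "bert-base-nli-mean-tokens" is an assumption; any sentence-transformers
# model with 768-dimensional output would match bert_input_size=768 above.
def regenerate_gnews_embeddings(data_dir):
    embeddings = bert_embeddings_from_file(
        data_dir + "gnews/GoogleNews.txt",
        "bert-base-nli-mean-tokens")  # assumed 768-dim model
    with open(data_dir + "gnews/bert_embeddings_gnews", "wb") as outfile:
        pickle.dump(embeddings, outfile)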
def test_training_all_classes_ctm(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + "sample_text_document", "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    # CTM with combined inference; distiluse-base-multilingual-cased
    # produces 512-dimensional embeddings, hence bert_input_size=512
    ctm = CTM(input_size=len(handler.vocab), bert_input_size=512, num_epochs=1,
              inference_type="combined", n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    # ZeroShotTM on the same dataset
    ctm = ZeroShotTM(input_size=len(handler.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    # CombinedTM on the same dataset
    ctm = CombinedTM(input_size=len(handler.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    # same pipeline, but building the handler and the embeddings from an
    # in-memory list of documents instead of a file
    with open(data_dir + "sample_text_document") as filino:
        data = filino.readlines()

    handler = TextHandler(sentences=data)
    handler.prepare()  # create vocabulary and training data
    train_bert = bert_embeddings_from_list(
        data, "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab), bert_input_size=512, num_epochs=1,
              inference_type="combined", n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    # QuickText bundles the BoW and embedding preparation into one object
    qt = QuickText("distiluse-base-multilingual-cased",
                   text_for_bow=data, text_for_bert=data)
    dataset = qt.load_dataset()
    ctm = ZeroShotTM(input_size=len(qt.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    # a QuickText object can also be rebuilt from a saved configuration
    qt_from_conf = QuickText("distiluse-base-multilingual-cased", None, None)
    qt_from_conf.load_configuration(qt.bow, qt.data_bert, qt.vocab, qt.idx2token)
    dataset = qt_from_conf.load_dataset()
    ctm = ZeroShotTM(input_size=len(qt.vocab), bert_input_size=512,
                     num_epochs=1, n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
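# The tests above receive `data_dir` via a pytest fixture that is not shown
# in this file. A minimal sketch follows, assuming the sample files
# (gnews/GoogleNews.txt, sample_text_document, ...) sit in a data/ directory
# next to this module; the real location depends on the repository layout.
import os

import pytest


@pytest.fixture
def data_dir():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") + "/"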