Example no. 1
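
These two tests appear to be excerpted from the contextualized-topic-models test suite and are not self-contained: the first exercises validation-based early stopping and checkpointing, the second runs the CTM, ZeroShotTM and CombinedTM classes over several data-preparation paths. A hedged guess at the imports they rely on (module paths may differ across library versions):

import os

from contextualized_topic_models.models.ctm import CTM, ZeroShotTM, CombinedTM
from contextualized_topic_models.utils.data_preparation import (
    TopicModelDataPreparation, TextHandler, QuickText,
    bert_embeddings_from_file, bert_embeddings_from_list)
from contextualized_topic_models.datasets.dataset import CTMDataset
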
def test_validation_set(data_dir):

    with open(data_dir + '/gnews/GoogleNews.txt') as filino:
        data = filino.readlines()

    tp = TopicModelDataPreparation("distiluse-base-multilingual-cased")

    training_dataset = tp.create_training_set(data[:100], data[:100])
    validation_dataset = tp.create_validation_set(data[100:105], data[100:105])

    ctm = ZeroShotTM(input_size=len(tp.vocab),
                     bert_input_size=512,
                     num_epochs=100,
                     n_components=5)
    ctm.fit(training_dataset,
            validation_dataset=validation_dataset,
            patience=5,
            save_dir=data_dir + 'test_checkpoint')

    assert os.path.exists(data_dir + "test_checkpoint")
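
Both tests take a data_dir argument, which pytest would normally inject through a fixture; a minimal sketch, with a hypothetical resource path:

import pytest


@pytest.fixture
def data_dir():
    # hypothetical location of the sample files used by the tests above
    return "contextualized_topic_models/data/"
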
def test_training_all_classes_ctm(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)

    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = ZeroShotTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = CombinedTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    with open(data_dir + 'sample_text_document') as filino:
        data = filino.readlines()

    handler = TextHandler(sentences=data)
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_list(
        data, "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)

    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)

    assert len(thetas) == len(train_bert)

    qt = QuickText("distiluse-base-multilingual-cased",
                   text_for_bow=data,
                   text_for_bert=data)

    dataset = qt.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    qt_from_conf = QuickText("distiluse-base-multilingual-cased", None, None)
    qt_from_conf.load_configuration(qt.bow, qt.data_bert, qt.vocab,
                                    qt.idx2token)
    dataset = qt_from_conf.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
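
The QuickText-based path used in the last block can be condensed into a small helper. A sketch that reuses only the calls shown above; the function name and defaults are illustrative, not library API:

def train_zeroshot_topics(documents, n_topics=5, epochs=10):
    # build the bag-of-words and contextual embeddings in one step
    qt = QuickText("distiluse-base-multilingual-cased",
                   text_for_bow=documents,
                   text_for_bert=documents)
    dataset = qt.load_dataset()

    # train a zero-shot contextualized topic model on the prepared dataset
    model = ZeroShotTM(input_size=len(qt.vocab),
                       bert_input_size=512,
                       num_epochs=epochs,
                       n_components=n_topics)
    model.fit(dataset)
    return model, model.get_topic_lists(10)
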
Example no. 3
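
The fragment below starts mid-script: documents, stopwords_lang and the args namespace are defined earlier in the original file, along with the QuickText and ZeroShotTM imports shown in the first example. A hedged sketch of the setup it presumes; the stopwords default and the input file format are assumptions:

from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing

with open(args.train_data) as f:
    documents = [line.strip() for line in f]

stopwords_lang = "english"  # assumed default; the original sets this elsewhere
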
sp = WhiteSpacePreprocessing(documents, stopwords_language=stopwords_lang)
preprocessed_documents, unpreprocessed_corpus, vocab = sp.preprocess()

# get BERT embeddings of docs
print("Get BERT encoding")
qt = QuickText("distiluse-base-multilingual-cased",
               text_for_bow=preprocessed_documents,
               text_for_bert=unpreprocessed_corpus)

training_dataset = qt.load_dataset()

# train CTM
print("Start training CTM")
ctm = ZeroShotTM(input_size=len(qt.vocab),
                 bert_input_size=512,
                 n_components=args.num_topics,
                 num_epochs=args.epochs)
ctm.fit(training_dataset)

# check topics found
print("Done training CTM!")
topics = ctm.get_topic_lists(20)
for i, topic in enumerate(topics):
    print("Topic", i, ":", ', '.join(topic))

# save trained model

model_file = "results/cldr/ctm_" + args.train_data.split(
    "/")[-1][:-4] + "_" + str(args.num_topics) + "topics_" + str(
        args.epochs) + "epochs.pkl"
with open(model_file, 'wb') as f:
    # the body of this block is truncated in the original snippet; given the .pkl
    # extension, pickling the trained model (import pickle) is the presumed intent
    pickle.dump(ctm, f)
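
Later the pickled model can be reloaded and queried again; a short sketch, assuming the file above was written with pickle as presumed:

import pickle

with open(model_file, 'rb') as f:
    loaded_ctm = pickle.load(f)

# the reloaded model exposes the same topic accessors used during training
print(loaded_ctm.get_topic_lists(10))
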
Example no. 4
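
This fragment also starts mid-script: it sweeps over several topic counts, trains a CombinedTM or ZeroShotTM depending on args.inference, and extracts the topic-word distribution. A hedged sketch of the setup it presumes; the topics grid is illustrative:

import logging
import time

import joblib
import torch
from contextualized_topic_models.models.ctm import CombinedTM, ZeroShotTM

topics = [25, 50, 75, 100]  # illustrative grid of topic counts to sweep over
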
data_preparation = joblib.load(args.data_preparation)
prepared_training_dataset = joblib.load(args.prepared_training_dataset)
logging.info("CTM training resources loaded")

df = create_model_dataframe()

logging.info(f'Vocab length: {len(data_preparation.vocab)}')

# sentence-embedding dimensionality: 768 for the English encoder, 512 for the
# multilingual "distiluse-base-multilingual-cased" model used in the other examples
input_size = 768 if args.lang == "en" else 512

for k in topics:
    start = time.time()

    # pick the model class according to the requested inference type
    if args.inference == "combined":
        ctm = CombinedTM(input_size=len(data_preparation.vocab),
                         bert_input_size=input_size,
                         n_components=k)
    else:
        ctm = ZeroShotTM(input_size=len(data_preparation.vocab),
                         bert_input_size=input_size,
                         n_components=k)

    ctm.fit(prepared_training_dataset)

    end = time.time()

    topic_words = ctm.get_topic_lists(20)
    unnormalized_topic_word_dist = ctm.get_topic_word_matrix()

    # normalize each topic's word scores into a probability distribution
    softmax = torch.nn.Softmax(dim=1)
    topic_word_dist = softmax(torch.from_numpy(unnormalized_topic_word_dist))

    topics_with_word_probs = get_topics_with_word_probabilities(
        ctm.train_data.idx2token, topic_word_dist)
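
get_topics_with_word_probabilities and create_model_dataframe are project-specific helpers not shown in this excerpt. A hypothetical sketch of what the former might do, pairing each topic's normalized word probabilities with the vocabulary tokens:

def get_topics_with_word_probabilities(idx2token, topic_word_dist, top_n=20):
    # topic_word_dist has one row per topic and one column per vocabulary word
    topics_with_probs = []
    for row in topic_word_dist:
        word_probs = [(idx2token[i], float(p)) for i, p in enumerate(row)]
        word_probs.sort(key=lambda pair: pair[1], reverse=True)
        topics_with_probs.append(word_probs[:top_n])
    return topics_with_probs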