import pickle
from math import isclose

import numpy as np

from contextualized_topic_models.models.ctm import CTM, ZeroShotTM, CombinedTM
from contextualized_topic_models.utils.data_preparation import (
    TextHandler, QuickText, bert_embeddings_from_file,
    bert_embeddings_from_list)
from contextualized_topic_models.datasets.dataset import CTMDataset


# data_dir is supplied by a pytest fixture pointing at the test resources
def test_training_with_saved_data(data_dir):
    handler = TextHandler(data_dir + "gnews/GoogleNews.txt")
    handler.prepare()  # create vocabulary and training data

    # load BERT data
    with open(data_dir + "gnews/bert_embeddings_gnews", "rb") as filino:
        training_bert = pickle.load(filino)

    training_dataset = CTMDataset(handler.bow, training_bert,
                                  handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=768,
              num_epochs=1,
              inference_type="combined",
              n_components=5)

    ctm.fit(training_dataset)  # run the model

    print(ctm.get_topics(2))

    ctm.get_doc_topic_distribution(training_dataset)
    assert isclose(np.sum(ctm.get_topic_word_distribution()[0]),
                   1,
                   rel_tol=1e-5,
                   abs_tol=0.0)
def test_shape_checks(data_dir):
    handler = TextHandler(data_dir + "gnews/GoogleNews.txt")
    handler.prepare()  # create vocabulary and training data

    # load BERT data
    with open(data_dir + "gnews/bert_embeddings_gnews", "rb") as filino:
        training_bert = pickle.load(filino)

    assert len(training_bert) == len(handler.bow.todense())
def test_embeddings_from_scratch(data_dir):

    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    assert np.array_equal(
        handler.bow.todense(),
        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                  [1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2]]))

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    assert len(train_bert) == len(handler.bow.todense())
def test_training(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_thetas(training_dataset)
    assert len(thetas) == len(train_bert)
import torch
import sys
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.evaluation.measures import CoherenceNPMI
from contextualized_topic_models.utils.data_preparation import TextHandler


def show_topics(topic_list):
    for idx, topic_tokens in enumerate(topic_list):
        print(idx)
        print(' '.join(topic_tokens))


if len(sys.argv) < 3:
    raise Exception("Usage: python {} {} {}".format(
        sys.argv[0], "<model_dir>", "<epoch>"))

handler = TextHandler(
    "contextualized_topic_models/data/wiki/wiki_train_en_prep.txt")
handler.prepare()
ctm = CTM(input_size=len(handler.vocab),
          inference_type="contextual",
          bert_input_size=768,
          num_data_loader_workers=1)
ctm.load(sys.argv[1], sys.argv[2])
with open("contextualized_topic_models/data/wiki/wiki_train_en_prep.txt",
          "r") as en:
    texts = [doc.split() for doc in en.read().splitlines()]

# obtain NPMI coherences on the topic modeled documents
show_topics(ctm.get_topic_lists(10))
npmi = CoherenceNPMI(texts=texts, topics=ctm.get_topic_lists(10))
print(npmi.score())
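# Example invocation of the script above (file name and paths are
# hypothetical); argv[1] is the saved model directory and argv[2] the
# checkpoint epoch passed to ctm.load:
#   python evaluate_wiki_npmi.py models/ctm_wiki_en 99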
def test_training_all_classes_ctm(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)

    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = ZeroShotTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = CombinedTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    with open(data_dir + 'sample_text_document') as filino:
        data = filino.readlines()

    handler = TextHandler(sentences=data)
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_list(
        data, "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)

    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)

    assert len(thetas) == len(train_bert)

    qt = QuickText("distiluse-base-multilingual-cased",
                   text_for_bow=data,
                   text_for_bert=data)

    dataset = qt.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    qt_from_conf = QuickText("distiluse-base-multilingual-cased", None, None)
    qt_from_conf.load_configuration(qt.bow, qt.data_bert, qt.vocab,
                                    qt.idx2token)
    dataset = qt_from_conf.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
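# QuickText takes separate texts for the BoW and for SBERT, which is how the
# zero-shot setup pairs preprocessed and unpreprocessed versions of the same
# corpus; a minimal sketch with illustrative file names:
#   preprocessed = [line.strip() for line in open("wiki_prep.txt")]
#   unpreprocessed = [line.strip() for line in open("wiki_unprep.txt")]
#   qt = QuickText("distiluse-base-multilingual-cased",
#                  text_for_bow=preprocessed,
#                  text_for_bert=unpreprocessed)
#   dataset = qt.load_dataset()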
import sys

from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import (
    TextHandler, bert_embeddings_from_file)
from contextualized_topic_models.datasets.dataset import CTMDataset


def show_topics(topic_list):
    for idx, topic_tokens in enumerate(topic_list):
        print(idx)
        print(' '.join(topic_tokens))


if len(sys.argv) < 4:
    raise Exception("Usage: python {} {} {} {}".format(sys.argv[0],
                                                       "<ctm_model>",
                                                       "<checkpoint>",
                                                       "<sbert_model>"))

handler_en = TextHandler(
    "contextualized_topic_models/data/wiki/wiki_test_en_prep_sub.txt")
# handler = TextHandler("contextualized_topic_models/data/wiki/iqos_corpus_prep_en.txt")
handler_en.prepare()

testing_bert_en = bert_embeddings_from_file(
    "contextualized_topic_models/data/wiki/wiki_test_en_unprep_sub.txt",
    sys.argv[3])
testing_dataset_en = CTMDataset(handler_en.bow, testing_bert_en,
                                handler_en.idx2token)

ctm = CTM(input_size=len(handler_en.vocab),
          inference_type="contextual",
          bert_input_size=768)
# ctm = torch.load(sys.argv[1], map_location="cpu")
ctm.load(sys.argv[1], sys.argv[2])
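# A possible continuation after loading (not part of the original script):
# inspect the restored topics and the size of the test-set document-topic
# distribution, mirroring the calls used in the tests above.
show_topics(ctm.get_topic_lists(10))
thetas_en = ctm.get_doc_topic_distribution(testing_dataset_en)
print(len(thetas_en))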