def test_training_with_saved_data(data_dir):
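    # Train a combined CTM for one epoch on the Google News sample, reusing
    # contextual (BERT) embeddings that were pre-computed and cached with pickle.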
    handler = TextHandler(data_dir + "gnews/GoogleNews.txt")
    handler.prepare()  # create vocabulary and training data

    # load BERT data
    with open(data_dir + "gnews/bert_embeddings_gnews", "rb") as filino:
        training_bert = pickle.load(filino)

    training_dataset = CTMDataset(handler.bow, training_bert,
                                  handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=768,
              num_epochs=1,
              inference_type="combined",
              n_components=5)

    ctm.fit(training_dataset)  # run the model

    print(ctm.get_topics(2))

    # smoke-test the document-topic distribution and check that each topic's
    # word distribution sums (approximately) to 1
    ctm.get_doc_topic_distribution(training_dataset)
    assert isclose(np.sum(ctm.get_topic_word_distribution()[0]),
                   1,
                   rel_tol=1e-5,
                   abs_tol=0.0)
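
# A minimal sketch (not part of the original test) of how the cached embeddings
# loaded above could be produced; the helper name and the 768-dimensional SBERT
# model name below are illustrative assumptions, not necessarily what was used
# to build the shipped pickle file.
def build_gnews_embeddings(data_dir):
    from contextualized_topic_models.utils.data_preparation import bert_embeddings_from_file
    import pickle

    # encode every Google News document with a 768-dim SBERT model so the
    # dimensionality matches bert_input_size=768 in the test above
    training_bert = bert_embeddings_from_file(data_dir + "gnews/GoogleNews.txt",
                                              "bert-base-nli-mean-tokens")
    # cache the embeddings so tests can reload them with pickle.load
    with open(data_dir + "gnews/bert_embeddings_gnews", "wb") as filino:
        pickle.dump(training_bert, filino)
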
class TopicModel():
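    """Contextual CTM trained incrementally on batches of a Twitter dataset."""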
    def __init__(self,
                 train_size: int = TRAIN_SIZE,
                 n_topics: int = NUM_TOPICS,
                 batch_size: int = BATCH_SIZE,
                 shuffle=True):
        data_loader = TwitterDataset("TRAIN",
                                     train_size,
                                     batch_size=batch_size,
                                     shuffle=shuffle)
        data_loader.init()
        self.data_loader = data_loader
        self.ctm = CTM(input_size=len(self.data_loader.vocab),
                       bert_input_size=512,
                       num_epochs=20,
                       batch_size=MINI_BATCH_SIZE,
                       inference_type="contextual",
                       n_components=n_topics,
                       reduce_on_plateau=True,
                       lr=1e-4,
                       hidden_sizes=HIDDEN_UNITS,
                       num_data_loader_workers=0)

    def train(self):
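        # encode and fit one batch of tweets at a time, logging the top-10
        # words of every topic after each incremental fit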
        for i in range(self.data_loader.n_batches):
            logger.info(f"Starting Batch {i+1}")
            self.data_loader.encoded_next_batch()
            training_dataset = CTMDataset(self.data_loader.bow,
                                          self.data_loader.bert_embeddings,
                                          self.data_loader.idx2token)
            self.ctm.fit(training_dataset)
            logger.info("\n---------------------------")
            logger.info("--------Topics---------")
            topics = self.ctm.get_topic_lists(10)
            for t in topics:
                logger.info(t)
            logger.info("---------------------------\n")
        return self.ctm.get_topic_lists(10)
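
# A usage sketch for the TopicModel class above (not part of the original
# snippet); it assumes TRAIN_SIZE, NUM_TOPICS, BATCH_SIZE, MINI_BATCH_SIZE,
# HIDDEN_UNITS and TwitterDataset are importable from the surrounding project.
if __name__ == "__main__":
    model = TopicModel()
    final_topics = model.train()  # top-10 word lists after the last batch
    for topic_words in final_topics:
        print(" ".join(topic_words))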
def test_training(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_thetas(training_dataset)
    assert len(thetas) == len(train_bert)
import torch
import sys
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.evaluation.measures import CoherenceNPMI
from contextualized_topic_models.utils.data_preparation import TextHandler
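
# Reload a trained contextual CTM checkpoint and score its topics with NPMI
# coherence against the preprocessed Wikipedia training documents.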


def show_topics(topic_list):
    for idx, topic_tokens in enumerate(topic_list):
        print(idx)
        print(' '.join(topic_tokens))


if len(sys.argv) < 3:
    raise Exception("Usage: python {} {} {}".format(sys.argv[0], "<model_file>",
                                                    "<checkpoint>"))

handler = TextHandler(
    "contextualized_topic_models/data/wiki/wiki_train_en_prep.txt")
handler.prepare()
ctm = CTM(input_size=len(handler.vocab),
          inference_type="contextual",
          bert_input_size=768,
          num_data_loader_workers=1)
ctm.load(sys.argv[1], sys.argv[2])
with open("contextualized_topic_models/data/wiki/wiki_train_en_prep.txt",
          "r") as en:
    texts = [doc.split() for doc in en.read().splitlines()]

# obtain NPMI coherences on the topic modeled documents
show_topics(ctm.get_topic_lists(10))
npmi = CoherenceNPMI(texts=texts, topics=ctm.get_topic_lists(10))
print(npmi.score())
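
# Besides NPMI, topic diversity can be computed from the same topic lists; a
# minimal sketch, assuming TopicDiversity is exposed by
# contextualized_topic_models.evaluation.measures as in the package README.
from contextualized_topic_models.evaluation.measures import TopicDiversity

diversity = TopicDiversity(ctm.get_topic_lists(10))
print(diversity.score(topk=10))  # proportion of unique words across all topics' top-10 words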
def test_training_all_classes_ctm(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)

    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    # ZeroShotTM: the contextual-only variant of the model
    ctm = ZeroShotTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    # CombinedTM: uses both the BoW and the contextual embeddings
    ctm = CombinedTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    with open(data_dir + 'sample_text_document') as filino:
        data = filino.readlines()

    handler = TextHandler(sentences=data)
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_list(
        data, "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)

    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)

    assert len(thetas) == len(train_bert)

    # QuickText prepares the vocabulary, BoW and contextual embeddings from raw
    # documents in a single step
    qt = QuickText("distiluse-base-multilingual-cased",
                   text_for_bow=data,
                   text_for_bert=data)

    dataset = qt.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    # rebuild a QuickText object from an already-computed configuration
    qt_from_conf = QuickText("distiluse-base-multilingual-cased", None, None)
    qt_from_conf.load_configuration(qt.bow, qt.data_bert, qt.vocab,
                                    qt.idx2token)
    dataset = qt_from_conf.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
import os
import numpy as np
import pickle
import torch
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.datasets.dataset import CTMDatasetTopReg
from contextualized_topic_models.utils.data_preparation import (
    TextHandler, labels_from_file, bert_embeddings_from_file)
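
# Train a CTM with the nb_labels argument (turning it into a TCCTM) on the
# preprocessed English Wikipedia corpus, supervised by labels read from a JSON file.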

handler = TextHandler(
    "contextualized_topic_models/data/wiki/wiki_train_en_prep.txt")
handler.prepare()

train_bert = bert_embeddings_from_file(
    'contextualized_topic_models/data/wiki/wiki_train_en_unprep.txt',
    '../sentence-transformers/sentence_transformers/output/training_wiki_topics_4_xlm-roberta-base-2020-10-24_13-38-14/')
labels, nb_labels = labels_from_file(
    'contextualized_topic_models/data/wiki/topic_full_misc.json')
training_dataset = CTMDatasetTopReg(handler.bow, train_bert, handler.idx2token,
                                    labels)

num_topics = 100
# the nb_labels argument turns this from a CTM into a TCCTM
ctm = CTM(input_size=len(handler.vocab),
          bert_input_size=768,
          num_epochs=60,
          hidden_sizes=(100, ),
          inference_type="contextual",
          n_components=num_topics,
          num_data_loader_workers=0,
          nb_labels=nb_labels)
ctm.fit(training_dataset)
ctm.save("models/wiki/wiki_xlmr_en_topics_4_topreg")
import pickle
import torch
import sys
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.datasets.dataset import CTMDataset
from contextualized_topic_models.utils.data_preparation import TextHandler, bert_embeddings_from_file
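
# Evaluate a trained contextual CTM on a held-out subset of the preprocessed
# English Wikipedia test data.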

NUM_TEST_TOKENS = 683563

def show_topics(topic_list):
    for idx, topic_tokens in enumerate(topic_list):
        print(idx)
        print(' '.join(topic_tokens))

if len(sys.argv) < 4:
    raise Exception("Usage: python {} {} {} {}".format(sys.argv[0], "<model_file>",
                                                       "<checkpoint>", "<sbert_model>"))

handler = TextHandler("contextualized_topic_models/data/wiki/wiki_test_en_prep_sub.txt")
# handler = TextHandler("contextualized_topic_models/data/iqos/iqos_corpus_prep_en.txt")
handler.prepare()

ctm = CTM(input_size=len(handler.vocab), inference_type="contextual", bert_input_size=768)
ctm.load(sys.argv[1], sys.argv[2])

test_bert = bert_embeddings_from_file(
    'contextualized_topic_models/data/wiki/wiki_test_en_unprep_sub.txt',
    sys.argv[3])
testing_dataset = CTMDataset(handler.bow, test_bert, handler.idx2token)

# print(ctm.get_topic_lists(10))
# show_topics(ctm.get_topic_lists(10))
ctm.test(testing_dataset, NUM_TEST_TOKENS)
                                                       "<checkpoint>",
                                                       "<sbert_model>"))

handler_en = TextHandler(
    "contextualized_topic_models/data/wiki/wiki_test_en_prep_sub.txt")
# handler = TextHandler("contextualized_topic_models/data/wiki/iqos_corpus_prep_en.txt")
handler_en.prepare()

testing_bert_en = bert_embeddings_from_file(
    "contextualized_topic_models/data/wiki/wiki_test_en_unprep_sub.txt",
    sys.argv[3])
testing_dataset_en = CTMDataset(handler_en.bow, testing_bert_en,
                                handler_en.idx2token)

ctm = CTM(input_size=len(handler_en.vocab),
          inference_type="contextual",
          bert_input_size=768)
# ctm = torch.load(sys.argv[1], map_location="cpu")
ctm.load(sys.argv[1], sys.argv[2])

num_topics = 100
thetas_en = ctm.get_thetas(testing_dataset_en, n_samples=100)
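# write the dominant (argmax) topic of every test document to a plain-text file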
with open("temp/topics_en_simple.txt", 'w') as test_out:
    topics = np.squeeze(np.argmax(thetas_en, axis=1).T)
    for topic in topics:
        test_out.write(str(topic) + '\n')

# randomly shuffled en baseline
# np.random.seed(3)
# np.random.shuffle(thetas_en)