Example #1
    def load_dataset(self):
        if self.loaded_from_config:
            training_dataset = CTMDataset(self.bow, self.data_bert, self.idx2token)
        else:
            self.prepare_bow()

            if self.data_bert is None:
                if self.text_for_bert is not None:
                    self.data_bert = bert_embeddings_from_list(self.text_for_bert, self.bert_model)
                else:
                    self.data_bert = bert_embeddings_from_list(self.text_for_bow, self.bert_model)

            training_dataset = CTMDataset(self.bow, self.data_bert, self.idx2token)

        return training_dataset
def test_training_with_saved_data(data_dir):
    handler = TextHandler(data_dir + "gnews/GoogleNews.txt")
    handler.prepare()  # create vocabulary and training data

    # load BERT data
    with open(data_dir + "gnews/bert_embeddings_gnews", "rb") as filino:
        training_bert = pickle.load(filino)

    training_dataset = CTMDataset(handler.bow, training_bert,
                                  handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=768,
              num_epochs=1,
              inference_type="combined",
              n_components=5)

    ctm.fit(training_dataset)  # run the model

    print(ctm.get_topics(2))

    ctm.get_doc_topic_distribution(training_dataset)
    assert isclose(np.sum(ctm.get_topic_word_distribution()[0]),
                   1,
                   rel_tol=1e-5,
                   abs_tol=0.0)
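The pickled embeddings loaded above can be produced once and cached. A minimal sketch, assuming the raw Google News file holds one document per line and reusing bert_embeddings_from_list as in the other examples; the sentence-transformers model name is a placeholder whose embedding size must match bert_input_size:

# Hedged sketch: precompute and cache contextualized embeddings so later runs
# can reload them with pickle instead of re-encoding the corpus.
with open(data_dir + "gnews/GoogleNews.txt") as f:
    docs = [line.strip() for line in f]

training_bert = bert_embeddings_from_list(docs, "<sbert-model-name>")  # placeholder model name

with open(data_dir + "gnews/bert_embeddings_gnews", "wb") as f:
    pickle.dump(training_bert, f)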
Example #3
    def load_dataset(self):
        self.prepare_bow()

        if self.unpreprocessed_sentences is not None:
            testing_bert = bert_embeddings_from_list(
                self.unpreprocessed_sentences, self.bert_model)
        else:
            testing_bert = bert_embeddings_from_list(
                self.preprocessed_sentences, self.bert_model)

        training_dataset = CTMDataset(self.bow, testing_bert, self.idx2token)
        return training_dataset
Example #4
    def create_test_set(self, text_for_contextual, text_for_bow=None):

        if self.contextualized_model is None:
            raise Exception("You should define a contextualized model if you want to create the embeddings")

        if text_for_bow is not None:
            test_bow_embeddings = self.vectorizer.transform(text_for_bow)
        else:
            # dummy matrix
            test_bow_embeddings = scipy.sparse.csr_matrix(np.zeros((len(text_for_contextual), 1)))
        test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model)

        return CTMDataset(test_bow_embeddings, test_contextualized_embeddings, self.id2token)
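When no bag-of-words text is supplied, create_test_set falls back to a dummy all-zeros sparse matrix, so inference on the held-out documents relies only on the contextualized embeddings. A minimal usage sketch, assuming qt is an instance of the data-preparation class that defines this method and ctm is an already trained model:

# Hedged sketch: build a test set from raw documents only and infer topics.
testing_dataset = qt.create_test_set(text_for_contextual=unpreprocessed_test_docs)
test_thetas = ctm.get_doc_topic_distribution(testing_dataset)  # one distribution per document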
Example #5
    def create_training_set(self, text_for_contextual, text_for_bow):

        if self.contextualized_model is None:
            raise Exception("You should define a contextualized model if you want to create the embeddings")

        # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users
        self.vectorizer = CountVectorizer()

        train_bow_embeddings = self.vectorizer.fit_transform(text_for_bow)
        train_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model)
        self.vocab = self.vectorizer.get_feature_names()
        self.id2token = {k: v for k, v in zip(range(0, len(self.vocab)), self.vocab)}

        return CTMDataset(train_bow_embeddings, train_contextualized_embeddings, self.id2token)
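A minimal end-to-end sketch tying create_training_set to model training, assuming qt is an instance of the class defining these methods and reusing the CombinedTM constructor arguments shown in the tests below; the corpus variables and hyperparameters are illustrative:

# Hedged sketch: raw documents feed the contextualized encoder, preprocessed
# documents feed the bag-of-words vectorizer built inside create_training_set.
training_dataset = qt.create_training_set(text_for_contextual=raw_docs,
                                          text_for_bow=preprocessed_docs)

ctm = CombinedTM(input_size=len(qt.vocab),
                 bert_input_size=512,  # must match the contextualized model's embedding size
                 num_epochs=20,
                 n_components=50)
ctm.fit(training_dataset)
print(ctm.get_topic_lists(10))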
Example #6
    def train(self):
        for i in range(self.data_loader.n_batches):
            logger.info(f"Starting Batch {i+1}")
            self.data_loader.encoded_next_batch()
            training_dataset = CTMDataset(self.data_loader.bow,
                                          self.data_loader.bert_embeddings,
                                          self.data_loader.idx2token)
            self.ctm.fit(training_dataset)
            logger.info("\n---------------------------")
            logger.info("--------Topics---------")
            topics = self.ctm.get_topic_lists(10)
            for t in topics:
                logger.info(t)
            logger.info("---------------------------\n")
        return self.ctm.get_topic_lists(10)
def test_training(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_thetas(training_dataset)
    assert len(thetas) == len(train_bert)
def test_training_all_classes_ctm(data_dir):
    handler = TextHandler(data_dir + "sample_text_document")
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(
        data_dir + 'sample_text_document', "distiluse-base-multilingual-cased")

    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)

    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = ZeroShotTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    ctm = CombinedTM(input_size=len(handler.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(train_bert)

    with open(data_dir + 'sample_text_document') as filino:
        data = filino.readlines()

    handler = TextHandler(sentences=data)
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_list(
        data, "distiluse-base-multilingual-cased")
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=512,
              num_epochs=1,
              inference_type="combined",
              n_components=5)
    ctm.fit(training_dataset)  # run the model
    topics = ctm.get_topic_lists(2)

    assert len(topics) == 5
    thetas = ctm.get_doc_topic_distribution(training_dataset)

    assert len(thetas) == len(train_bert)

    qt = QuickText("distiluse-base-multilingual-cased",
                   text_for_bow=data,
                   text_for_bert=data)

    dataset = qt.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5

    qt_from_conf = QuickText("distiluse-base-multilingual-cased", None, None)
    qt_from_conf.load_configuration(qt.bow, qt.data_bert, qt.vocab,
                                    qt.idx2token)
    dataset = qt_from_conf.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=512,
                     num_epochs=1,
                     n_components=5)
    ctm.fit(dataset)  # run the model
    topics = ctm.get_topic_lists(2)
    assert len(topics) == 5
Example #9
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import bert_embeddings_from_file, bert_embeddings_from_list
from contextualized_topic_models.evaluation.measures import CoherenceNPMI
import os
import numpy as np
import pickle
import torch
from contextualized_topic_models.datasets.dataset import CTMDataset
from contextualized_topic_models.utils.data_preparation import TextHandler

handler = TextHandler("contextualized_topic_models/data/wiki/wiki_train_en_prep.txt")
handler.prepare()

train_bert = bert_embeddings_from_file('contextualized_topic_models/data/wiki/wiki_train_en_unprep.txt', \
        '../sentence-transformers/sentence_transformers/output/training_wiki_topics_4_xlm-roberta-base-2020-10-24_13-38-14')
training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

num_topics = 100
ctm = CTM(input_size=len(handler.vocab), bert_input_size=768, num_epochs=60, hidden_sizes=(100,),
          inference_type="contextual", n_components=num_topics, num_data_loader_workers=0)
ctm.fit(training_dataset)
ctm.save("models/wiki/wiki_xlmr_en_topics_4")
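CoherenceNPMI is imported above but never used; a minimal evaluation sketch, assuming its constructor takes the topic lists and the tokenized reference texts (the argument names are an assumption and should be checked against the library's evaluation module):

# Hedged sketch: score the trained topics with NPMI coherence over the
# preprocessed training corpus (one tokenized document per line).
with open("contextualized_topic_models/data/wiki/wiki_train_en_prep.txt") as f:
    texts = [line.split() for line in f]

npmi = CoherenceNPMI(topics=ctm.get_topic_lists(10), texts=texts)
print(npmi.score())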
    def load(self, contextualized_embeddings, bow_embeddings, id2token):
        return CTMDataset(bow_embeddings, contextualized_embeddings, id2token)
import pickle
import torch
import sys
from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.datasets.dataset import CTMDataset
from contextualized_topic_models.utils.data_preparation import TextHandler, bert_embeddings_from_file

NUM_TEST_TOKENS = 683563

def show_topics(topic_list):
    for idx, topic_tokens in enumerate(topic_list):
        print(idx)
        print(' '.join(topic_tokens))

if len(sys.argv) < 4:
    raise Exception("Usage: python {} {} {} {}".format(sys.argv[0],
                                                       "<ctm_model>",
                                                       "<checkpoint>",
                                                       "<sbert_model>"))

handler = TextHandler("contextualized_topic_models/data/wiki/wiki_test_en_prep_sub.txt")
# handler = TextHandler("contextualized_topic_models/data/iqos/iqos_corpus_prep_en.txt")
handler.prepare()

ctm = CTM(input_size=len(handler.vocab), inference_type="contextual", bert_input_size=768)
ctm.load(sys.argv[1], sys.argv[2])

test_bert = bert_embeddings_from_file('contextualized_topic_models/data/wiki/wiki_test_en_unprep_sub.txt', \
        sys.argv[3])
testing_dataset = CTMDataset(handler.bow, test_bert, handler.idx2token)

# print(ctm.get_topic_lists(10))
# show_topics(ctm.get_topic_lists(10))
ctm.test(testing_dataset, NUM_TEST_TOKENS)
Example #12
if len(sys.argv) < 4:
    raise Exception("Usage: python {} {} {} {}".format(sys.argv[0],
                                                       "<ctm_model>",
                                                       "<checkpoint>",
                                                       "<sbert_model>"))

handler_en = TextHandler(
    "contextualized_topic_models/data/wiki/wiki_test_en_prep_sub.txt")
# handler = TextHandler("contextualized_topic_models/data/wiki/iqos_corpus_prep_en.txt")
handler_en.prepare()

testing_bert_en = bert_embeddings_from_file(
    "contextualized_topic_models/data/wiki/wiki_test_en_unprep_sub.txt",
    sys.argv[3])
testing_dataset_en = CTMDataset(handler_en.bow, testing_bert_en,
                                handler_en.idx2token)

ctm = CTM(input_size=len(handler_en.vocab),
          inference_type="contextual",
          bert_input_size=768)
# ctm = torch.load(sys.argv[1], map_location="cpu")
ctm.load(sys.argv[1], sys.argv[2])

num_topics = 100
thetas_en = ctm.get_thetas(testing_dataset_en, n_samples=100)
with open("temp/topics_en_simple.txt", 'w') as test_out:
    topics = np.squeeze(np.argmax(thetas_en, axis=1).T)
    for topic in topics:
        test_out.write(str(topic) + '\n')

# randomly shuffled en baseline