def test_training_with_saved_data(data_dir):
    """Train a combined CTM on GoogleNews using pre-computed BERT embeddings.

    The embeddings are loaded from a pickle fixture in the test data
    directory and fed to a CTM configured with ``bert_input_size=768``.
    The test checks that:

    * the model produces ``n_components`` topics,
    * one document-topic row is produced per embedded document, and
    * every topic's word distribution sums to 1.
    """
    handler = TextHandler(data_dir + "gnews/GoogleNews.txt")
    handler.prepare()  # create vocabulary and training data

    # Load pre-computed BERT embeddings. This is a trusted test fixture, so
    # unpickling it is acceptable here (never unpickle untrusted data).
    with open(data_dir + "gnews/bert_embeddings_gnews", "rb") as filino:
        training_bert = pickle.load(filino)

    training_dataset = CTMDataset(handler.bow, training_bert,
                                  handler.idx2token)

    n_components = 5
    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=768,
              num_epochs=1,
              inference_type="combined",
              n_components=n_components)

    ctm.fit(training_dataset)  # run the model

    # Previously only printed; assert on the result instead.
    topics = ctm.get_topics(2)
    assert len(topics) == n_components

    # Previously computed and discarded; one row per input document expected.
    thetas = ctm.get_doc_topic_distribution(training_dataset)
    assert len(thetas) == len(training_bert)

    # Every topic (not just the first) should be a probability distribution
    # over the vocabulary.
    for topic_words in ctm.get_topic_word_distribution():
        assert isclose(np.sum(topic_words),
                       1,
                       rel_tol=1e-5,
                       abs_tol=0.0)
def _fit_and_check_topics(model, dataset, n_topics=5, expected_docs=None):
    """Fit *model* on *dataset* and sanity-check its output.

    Asserts that ``get_topic_lists(2)`` returns *n_topics* topics and, when
    *expected_docs* is given, that ``get_doc_topic_distribution`` returns
    one row per document.
    """
    model.fit(dataset)  # run the model
    topics = model.get_topic_lists(2)
    assert len(topics) == n_topics

    if expected_docs is not None:
        thetas = model.get_doc_topic_distribution(dataset)
        assert len(thetas) == expected_docs


def test_training_all_classes_ctm(data_dir):
    """Smoke-test every model class and both data-loading paths.

    Covers:

    * CTM (combined), ZeroShotTM and CombinedTM trained on embeddings
      computed from a file;
    * CTM trained on embeddings computed from an in-memory list;
    * ZeroShotTM trained on a QuickText dataset, built both directly and
      from a previously saved configuration.
    """
    sbert_model = "distiluse-base-multilingual-cased"
    # Embedding size passed to every model as ``bert_input_size``.
    bert_size = 512
    n_components = 5
    text_path = data_dir + 'sample_text_document'

    # --- Embeddings computed from a file -------------------------------
    handler = TextHandler(text_path)
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_file(text_path, sbert_model)
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)
    vocab_size = len(handler.vocab)

    for model in (
            CTM(input_size=vocab_size, bert_input_size=bert_size,
                num_epochs=1, inference_type="combined",
                n_components=n_components),
            ZeroShotTM(input_size=vocab_size, bert_input_size=bert_size,
                       num_epochs=1, n_components=n_components),
            CombinedTM(input_size=vocab_size, bert_input_size=bert_size,
                       num_epochs=1, n_components=n_components),
    ):
        _fit_and_check_topics(model, training_dataset,
                              n_topics=n_components,
                              expected_docs=len(train_bert))

    # --- Embeddings computed from an in-memory list of sentences -------
    with open(text_path) as filino:
        data = filino.readlines()

    handler = TextHandler(sentences=data)
    handler.prepare()  # create vocabulary and training data

    train_bert = bert_embeddings_from_list(data, sbert_model)
    training_dataset = CTMDataset(handler.bow, train_bert, handler.idx2token)

    ctm = CTM(input_size=len(handler.vocab),
              bert_input_size=bert_size,
              num_epochs=1,
              inference_type="combined",
              n_components=n_components)
    _fit_and_check_topics(ctm, training_dataset,
                          n_topics=n_components,
                          expected_docs=len(train_bert))

    # --- QuickText, built directly --------------------------------------
    qt = QuickText(sbert_model,
                   text_for_bow=data,
                   text_for_bert=data)
    dataset = qt.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=bert_size,
                     num_epochs=1,
                     n_components=n_components)
    _fit_and_check_topics(ctm, dataset, n_topics=n_components)

    # --- QuickText, restored from a saved configuration -----------------
    qt_from_conf = QuickText(sbert_model, None, None)
    qt_from_conf.load_configuration(qt.bow, qt.data_bert, qt.vocab,
                                    qt.idx2token)
    dataset = qt_from_conf.load_dataset()

    ctm = ZeroShotTM(input_size=len(qt.vocab),
                     bert_input_size=bert_size,
                     num_epochs=1,
                     n_components=n_components)
    _fit_and_check_topics(ctm, dataset, n_topics=n_components)