Esempio n. 1
0
def test_extract_incorrect_embeddings():
    """An unsupported ``language`` must raise ValueError before fitting completes.

    The error may come from the constructor or from ``fit``; both calls
    happen inside the ``pytest.raises`` context.
    """
    sample_docs = ["some document"]
    with pytest.raises(ValueError):
        BERTopic(language="Unknown language").fit(sample_docs)
# Import dependencies and packages.
import numpy as np
import pandas as pd
from copy import deepcopy
from bertopic import BERTopic

# Load data.
# NOTE(review): both values below are placeholders -- set DATA_PATH to the
# real CSV file and TEXT_COLUMN to the column holding the raw document text
# before running this script.
DATA_PATH = " "
TEXT_COLUMN = "text"

df = pd.read_csv(DATA_PATH)
# BERTopic is fed a plain list of document strings. Missing cells are dropped
# so they do not become the literal string "nan".
docs = df[TEXT_COLUMN].dropna().astype(str).tolist()

# Create topics: fit the model and get a topic id (and probabilities) per document.
model = BERTopic(language="english")
topics, probs = model.fit_transform(docs)

# Extract the most frequent topics. In a script (unlike a notebook) bare
# expressions are discarded, so print the results explicitly.
print(model.get_topic_freq())

# Get individual topics.
print(model.get_topic(0))
print(model.get_topic(2))

# Visualize topics. visualize_topics() returns a Plotly figure; outside a
# notebook it is not rendered automatically, so show it explicitly.
fig = model.visualize_topics()
fig.show()
Esempio n. 3
0
def test_full():
    """End-to-end check of fit, transform, find, update and reduce on BERTopic.

    Relies on the ``newsgroup_docs`` corpus defined elsewhere in the test module.
    """
    model = BERTopic(language="english",
                     verbose=True,
                     n_neighbors=5,
                     min_topic_size=5)

    # Test fit
    topics, probs = model.fit_transform(newsgroup_docs)

    # Every assigned topic, and every topic in the frequency table, must
    # expose at least 10 representative words.
    for topic in set(topics):
        words = model.get_topic(topic)[:10]
        assert len(words) == 10

    # get_topic() does not change topic sizes, so one snapshot of the
    # frequency table serves the checks below.
    topic_freq = model.get_topic_freq()
    for topic in topic_freq.Topic:
        words = model.get_topic(topic)[:10]
        assert len(words) == 10

    n_docs = len(newsgroup_docs)
    n_topics = len(topic_freq)
    assert n_topics > 2
    # One probability row per document (instead of a hard-coded corpus size);
    # one column fewer than the frequency table (presumably the outlier
    # topic -1 has no probability column -- confirm against BERTopic).
    assert probs.shape == (n_docs, n_topics - 1)
    assert len(model.get_topics()) == n_topics

    # Test transform
    doc = "This is a new document to predict."
    topics_test, probs_test = model.transform([doc])

    assert len(probs_test) == len(model.get_topic_freq()) - 1
    assert len(topics_test) == 1

    # Test find topic
    similar_topics, similarity = model.find_topics("query", top_n=2)
    assert len(similar_topics) == 2
    assert len(similarity) == 2
    assert max(similarity) <= 1

    # Test update topics: a different n-gram range must change topic 1's
    # words, and updating again with defaults must restore them.
    original_words = model.get_topic(1)[:10]
    model.update_topics(newsgroup_docs,
                        topics,
                        n_gram_range=(2, 2),
                        stop_words="english")
    updated_words = model.get_topic(1)[:10]
    model.update_topics(newsgroup_docs, topics)
    restored_words = model.get_topic(1)[:10]

    assert original_words != updated_words
    assert original_words == restored_words

    # Test topic reduction
    nr_topics = 2
    new_topics, new_probs = model.reduce_topics(newsgroup_docs,
                                                topics,
                                                probs,
                                                nr_topics=nr_topics)

    # Frequency table is re-fetched because reduce_topics changes it.
    assert len(model.get_topic_freq()) == nr_topics + 1
    assert len(new_topics) == len(topics)
    assert len(new_probs) == len(probs)
Esempio n. 4
0
def test_topic_reduction(reduced_topics):
    """ Test whether the topics are correctly reduced """
    model = BERTopic()
    # Start with two more topic ids than the target so there is something
    # to reduce. randint's upper bound is exclusive, so the ids drawn below
    # range over -1 .. initial_nr_topics - 2.
    initial_nr_topics = reduced_topics + 2
    model.nr_topics = reduced_topics
    # Seeded local generator: the random topic assignment is reproducible
    # and does not depend on (or disturb) global NumPy random state.
    rng = np.random.RandomState(42)
    old_documents = pd.DataFrame({
        "Document": newsgroup_docs,
        "ID": range(len(newsgroup_docs)),
        "Topic": rng.randint(-1, initial_nr_topics - 1, len(newsgroup_docs)),
    })
    model._update_topic_size(old_documents)
    model._extract_topics(old_documents.copy())
    old_freq = model.get_topic_freq()

    new_documents = model._reduce_topics(old_documents.copy())
    new_freq = model.get_topic_freq()

    # Total document count is preserved by the reduction.
    assert old_freq.Count.sum() == new_freq.Count.sum()
    # Each topic appears exactly once in both frequency tables.
    assert len(old_freq.Topic.unique()) == len(old_freq)
    assert len(new_freq.Topic.unique()) == len(new_freq)
    assert isinstance(model.mapped_topics, dict)
    # Every topic left in the frequency table is used by some document.
    assert not set(new_freq.Topic).difference(set(new_documents.Topic))
    assert model.mapped_topics
Esempio n. 5
0
def test_reduce_dimensionality(embeddings, shape):
    """Reduced embeddings must have ``shape`` rows and 5 columns."""
    reduced = BERTopic()._reduce_dimensionality(embeddings)
    expected_shape = (shape, 5)
    assert reduced.shape == expected_shape
Esempio n. 6
0
def test_extract_embeddings():
    """An unknown ``bert_model`` name must raise OSError by embedding time.

    The error may come from the constructor or from ``_extract_embeddings``;
    both calls happen inside the ``pytest.raises`` context.
    """
    sample_docs = ["Some document"]
    with pytest.raises(OSError):
        BERTopic(bert_model='not_a_model')._extract_embeddings(sample_docs)