import shutil

# NOTE: the import paths below assume an older Flair release (pre-0.4) that still
# shipped DocumentMeanEmbeddings, CharLMEmbeddings and TextClassifierTrainer.
from flair.data import Sentence
from flair.data_fetcher import NLPTask, NLPTaskDataFetcher
from flair.embeddings import (CharLMEmbeddings, DocumentMeanEmbeddings,
                              TokenEmbeddings, WordEmbeddings)
from flair.models import TextClassifier
from flair.trainers import TextClassifierTrainer


def test_text_classifier_multi_label():
    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    glove_embedding: WordEmbeddings = WordEmbeddings('en-glove')
    document_embeddings: DocumentMeanEmbeddings = DocumentMeanEmbeddings(
        [glove_embedding], True)

    model = TextClassifier(document_embeddings, label_dict, True)

    trainer = TextClassifierTrainer(model, corpus, label_dict, False)
    trainer.train('./results', max_epochs=2)

    # clean up results directory
    shutil.rmtree('./results')
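
As a side note, `make_label_dictionary()` returns a Flair `Dictionary` that maps every label string observed in the corpus to an integer index. A minimal inspection sketch is shown below; it assumes the `Dictionary.get_items()` accessor, and the IMDB labels in the comment are illustrative, not guaranteed values:

    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    print(len(label_dict))         # number of distinct labels in the corpus
    print(label_dict.get_items())  # e.g. ['POSITIVE', 'NEGATIVE'] (illustrative)
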
Example 2
def test_document_mean_embeddings():
    text = 'I love Berlin. Berlin is a great place to live.'
    sentence: Sentence = Sentence(text)

    glove: TokenEmbeddings = WordEmbeddings('en-glove')
    charlm: TokenEmbeddings = CharLMEmbeddings('mix-backward')

    embeddings: DocumentMeanEmbeddings = DocumentMeanEmbeddings(
        [glove, charlm])

    embeddings.embed(sentence)

    assert (len(sentence.get_embedding()) != 0)

    sentence.clear_embeddings()

    assert (len(sentence.get_embedding()) == 0)
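
For reference, a document mean embedding is simply the average of the token embeddings produced by the stacked word-level embedders. The pure-PyTorch sketch below illustrates the pooling operation itself; the tensor shapes are made up for the example and are not Flair internals:

    import torch

    token_embeddings = torch.randn(11, 128)           # one row per token (shapes illustrative)
    document_embedding = token_embeddings.mean(dim=0)  # a single vector for the whole sentence
    assert document_embedding.shape[0] == 128
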
Example 3
def test_text_classifier_multi_label():
    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    glove_embedding: WordEmbeddings = WordEmbeddings('en-glove')
    document_embeddings: DocumentMeanEmbeddings = DocumentMeanEmbeddings([glove_embedding], True)

    model = TextClassifier(document_embeddings, label_dict, True)

    trainer = TextClassifierTrainer(model, corpus, label_dict, False)
    trainer.train('./results', max_epochs=2)

    sentence = Sentence("Berlin is a really nice city.")

    for s in model.predict(sentence):
        for l in s.labels:
            assert(l.name is not None)
            assert(0.0 <= l.confidence <= 1.0)
            assert(type(l.confidence) is float)

    # clean up results directory
    shutil.rmtree('./results')
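
If the trained classifier were needed later, it could be reloaded before the `./results` directory is deleted. The sketch below assumes the `final-model.pt` file name conventionally written by the trainer and the `TextClassifier.load_from_file()` class method available in older Flair releases; both are assumptions rather than guarantees for every version:

    # Hypothetical reuse of the trained classifier (assumed old Flair API and file name).
    classifier = TextClassifier.load_from_file('./results/final-model.pt')
    prediction = classifier.predict(Sentence("Berlin is a really nice city."))
    for label in prediction[0].labels:
        print(label.name, label.confidence)
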