def test_train_charlm__nocache_load_use_classifier():
    """Train a CharLM-backed LSTM text classifier on IMDB, save it, reload it and predict.

    The character LM embeddings are created with ``use_cache=False``. Training
    runs for two epochs and writes its output under ``./results``. The test
    checks the labels predicted by the freshly trained model. It then reloads
    the saved ``final-model.pt`` and runs predictions on a normal sentence, a
    mixed batch and a whitespace-only sentence. The results directory is
    removed afterwards, even if a check fails.
    """
    results_path = './results'

    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    charlm_embedding: TokenEmbeddings = CharLMEmbeddings('news-forward-fast', use_cache=False)
    document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings([charlm_embedding], 128, 1, False, 64,
                                                                         False,
                                                                         False)

    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = TextClassifierTrainer(model, corpus, label_dict, False)

    try:
        trainer.train(results_path, max_epochs=2)

        sentence = Sentence("Berlin is a really nice city.")

        for s in model.predict(sentence):
            for label in s.labels:
                assert label.value is not None
                assert 0.0 <= label.score <= 1.0
                # Exact type check: rejects float subclasses, not just non-floats.
                assert type(label.score) is float

        # Load the saved model once, outside the prediction loop, so it is
        # always defined and not reloaded for every predicted sentence.
        loaded_model = TextClassifier.load_from_file(f'{results_path}/final-model.pt')

        sentence = Sentence('I love Berlin')
        sentence_empty = Sentence('       ')

        # Single sentence, a batch mixing a real and a whitespace-only
        # sentence, and a batch containing only the whitespace-only one.
        loaded_model.predict(sentence)
        loaded_model.predict([sentence, sentence_empty])
        loaded_model.predict([sentence_empty])
    finally:
        # Clean up the results directory even when a check above fails;
        # ignore_errors keeps a missing directory (e.g. training crashed before
        # writing anything) from masking the original error.
        shutil.rmtree(results_path, ignore_errors=True)


def test_train_load_use_classifier(results_base_path, tasks_base_path):
    """Train a GloVe-backed LSTM text classifier on IMDB, save it, reload it and predict.

    The IMDB corpus is read from ``tasks_base_path``. Training runs in test
    mode for two epochs and writes into ``results_base_path``. That argument
    must support the ``/`` operator; presumably it is a ``pathlib.Path``
    fixture, which should be confirmed in conftest. The test checks the labels
    predicted by the trained model. It then reloads ``final-model.pt`` and
    predicts on a normal sentence, a mixed batch and a whitespace-only
    sentence. Finally it deletes ``results_base_path``.
    """
    imdb_corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB,
                                                base_path=tasks_base_path)
    labels = imdb_corpus.make_label_dictionary()

    word_vectors: WordEmbeddings = WordEmbeddings('en-glove')
    doc_embedder: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings(
        [word_vectors], 128, 1, False, 64, False, False)

    classifier = TextClassifier(doc_embedder, labels, False)

    classifier_trainer = TextClassifierTrainer(classifier, imdb_corpus, labels,
                                               test_mode=True)
    classifier_trainer.train(str(results_base_path), max_epochs=2)

    probe = Sentence("Berlin is a really nice city.")

    for predicted in classifier.predict(probe):
        for label in predicted.labels:
            assert label.value is not None
            assert 0.0 <= label.score <= 1.0
            # Exact type check: rejects float subclasses, not just non-floats.
            assert type(label.score) is float

    model_file = results_base_path / 'final-model.pt'
    reloaded = TextClassifier.load_from_file(model_file)

    short_sentence = Sentence('I love Berlin')
    blank_sentence = Sentence('       ')

    # Single sentence, a batch mixing a real and a whitespace-only sentence,
    # and a batch containing only the whitespace-only one.
    for batch in (short_sentence,
                  [short_sentence, blank_sentence],
                  [blank_sentence]):
        reloaded.predict(batch)

    # clean up results directory
    shutil.rmtree(results_base_path)