def test_train_charlm__nocache_load_use_classifier():
    """Train a text classifier on IMDB with uncached CharLM embeddings,
    then reload the saved model from disk and predict on single sentences,
    batches, and empty input.

    NOTE(review): unlike the sibling test, this one hard-codes './results'
    instead of using the results_base_path fixture — confirm whether that is
    deliberate for the no-cache variant.
    """
    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB)
    label_dict = corpus.make_label_dictionary()

    # use_cache=False exercises the embeddings' no-cache code path
    glove_embedding: TokenEmbeddings = CharLMEmbeddings('news-forward-fast', use_cache=False)
    document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, False)
    trainer = TextClassifierTrainer(model, corpus, label_dict, False)

    try:
        trainer.train('./results', max_epochs=2)

        sentence = Sentence("Berlin is a really nice city.")
        for s in model.predict(sentence):
            for l in s.labels:
                assert l.value is not None
                assert 0.0 <= l.score <= 1.0
                assert isinstance(l.score, float)

        loaded_model = TextClassifier.load_from_file('./results/final-model.pt')

        sentence = Sentence('I love Berlin')
        sentence_empty = Sentence(' ')

        # prediction must cope with a single sentence, a mixed batch,
        # and a batch containing only an (effectively) empty sentence
        loaded_model.predict(sentence)
        loaded_model.predict([sentence, sentence_empty])
        loaded_model.predict([sentence_empty])
    finally:
        # clean up the results directory even when an assertion or
        # prediction above fails; ignore_errors covers the case where
        # training crashed before the directory was created
        shutil.rmtree('./results', ignore_errors=True)
def test_train_load_use_classifier(results_base_path, tasks_base_path):
    """Train a text classifier on IMDB with GloVe word embeddings,
    then reload the saved model from disk and predict on single sentences,
    batches, and empty input.

    :param results_base_path: fixture path where the trained model is written
    :param tasks_base_path: fixture path holding the task corpora
    """
    corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB, base_path=tasks_base_path)
    label_dict = corpus.make_label_dictionary()

    glove_embedding: WordEmbeddings = WordEmbeddings('en-glove')
    document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False)

    model = TextClassifier(document_embeddings, label_dict, False)
    trainer = TextClassifierTrainer(model, corpus, label_dict, test_mode=True)

    try:
        trainer.train(str(results_base_path), max_epochs=2)

        sentence = Sentence("Berlin is a really nice city.")
        for s in model.predict(sentence):
            for l in s.labels:
                assert l.value is not None
                assert 0.0 <= l.score <= 1.0
                assert isinstance(l.score, float)

        loaded_model = TextClassifier.load_from_file(results_base_path / 'final-model.pt')

        sentence = Sentence('I love Berlin')
        sentence_empty = Sentence(' ')

        # prediction must cope with a single sentence, a mixed batch,
        # and a batch containing only an (effectively) empty sentence
        loaded_model.predict(sentence)
        loaded_model.predict([sentence, sentence_empty])
        loaded_model.predict([sentence_empty])
    finally:
        # clean up the results directory even when an assertion or
        # prediction above fails; ignore_errors covers the case where
        # training crashed before the directory was created
        shutil.rmtree(results_base_path, ignore_errors=True)