def test_train_resume_classifier(results_base_path, tasks_base_path):
    """Train a topic classifier for two epochs, then resume it from its checkpoint.

    The first run writes ``checkpoint.pt`` into ``results_base_path``; the
    reloaded checkpoint is then resumed until epoch four. Resuming is expected
    to emit a ``UserWarning``.
    """
    label_name = "topic"

    imdb_corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb", label_type=label_name)
    topic_labels = imdb_corpus.make_label_dictionary(label_type=label_name)

    classifier = TextClassifier(
        document_embeddings=document_embeddings,
        label_dictionary=topic_labels,
        multi_label=False,
        label_type=label_name,
    )

    # First pass: two epochs, with checkpointing enabled so a resumable
    # snapshot is written to results_base_path.
    model_trainer = ModelTrainer(classifier, imdb_corpus)
    model_trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)
    del classifier

    # Second pass: reload the checkpoint and continue training up to epoch 4.
    restored_classifier = TextClassifier.load(results_base_path / "checkpoint.pt")
    with pytest.warns(UserWarning):
        model_trainer.resume(model=restored_classifier, max_epochs=4)

    # Drop the local trainer reference (nothing on disk is removed here).
    del model_trainer
def test_train_resume_tagger(results_base_path, tasks_base_path):
    """Train an NER tagger on a two-part MultiCorpus, then resume it from its checkpoint.

    The first run trains for two epochs with checkpointing enabled; the
    ``checkpoint.pt`` it writes into ``results_base_path`` is reloaded and
    training is resumed until epoch four.
    """
    fashion_corpus = flair.datasets.ColumnCorpus(
        data_folder=tasks_base_path / "fashion",
        column_format={0: "text", 3: "ner"},
    )
    germeval_corpus = flair.datasets.NER_GERMAN_GERMEVAL(base_path=tasks_base_path).downsample(0.1)
    combined_corpus = MultiCorpus([fashion_corpus, germeval_corpus])
    ner_dictionary = combined_corpus.make_label_dictionary("ner")

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=64,
        embeddings=turian_embeddings,
        tag_dictionary=ner_dictionary,
        tag_type="ner",
        use_crf=False,
    )

    # First pass: two epochs, writing a resumable checkpoint to results_base_path.
    model_trainer = ModelTrainer(tagger, combined_corpus)
    model_trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)
    del tagger

    # Second pass: reload the checkpoint and continue training up to epoch 4.
    restored_tagger = SequenceTagger.load(results_base_path / "checkpoint.pt")
    model_trainer.resume(model=restored_tagger, max_epochs=4)

    # Drop the local trainer reference. This only unbinds the name; it does not
    # delete anything under results_base_path.
    del model_trainer