Exemplo n.º 1
0
def test_train_resume_classifier(results_base_path, tasks_base_path):
    """Train a text classifier with checkpointing, then resume from the checkpoint.

    A ``TextClassifier`` for the "topic" label type is trained for two epochs
    on the "imdb" classification corpus found under ``tasks_base_path``, with
    checkpointing enabled. The checkpoint written to ``results_base_path`` is
    then loaded and training is resumed with ``max_epochs=4``. The resume call
    is required to emit a ``UserWarning``.

    Args:
        results_base_path: Directory that receives the training output and
            from which ``checkpoint.pt`` is loaded.
        tasks_base_path: Directory containing the "imdb" task data.
    """
    imdb_corpus = flair.datasets.ClassificationCorpus(
        tasks_base_path / "imdb", label_type="topic"
    )
    topic_labels = imdb_corpus.make_label_dictionary(label_type="topic")

    classifier = TextClassifier(
        document_embeddings=document_embeddings,
        label_dictionary=topic_labels,
        multi_label=False,
        label_type="topic",
    )

    # First phase: two epochs, writing a checkpoint along the way.
    trainer = ModelTrainer(classifier, imdb_corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    # Drop the local name so the second phase uses only the reloaded model.
    del classifier

    # Second phase: reload the checkpoint and continue with max_epochs=4.
    checkpoint_file = results_base_path / "checkpoint.pt"
    resumed_classifier = TextClassifier.load(checkpoint_file)
    with pytest.warns(UserWarning):
        trainer.resume(model=resumed_classifier, max_epochs=4)

    del trainer
Exemplo n.º 2
0
def test_train_resume_tagger(results_base_path, tasks_base_path):
    """Train a sequence tagger with checkpointing, then resume from the checkpoint.

    Two corpora are combined into a ``MultiCorpus``: a column corpus read from
    the "fashion" folder (column 0 as "text", column 3 as "ner") and a
    ``NER_GERMAN_GERMEVAL`` corpus downsampled with ``0.1``. A CRF-less
    ``SequenceTagger`` for the "ner" tag type is trained on it for two epochs
    with checkpointing enabled. The checkpoint is then loaded and training is
    resumed with ``max_epochs=4``.

    Args:
        results_base_path: Directory that receives the training output and
            from which ``checkpoint.pt`` is loaded.
        tasks_base_path: Directory containing the task data sets.
    """
    fashion_columns = {0: "text", 3: "ner"}
    fashion_corpus = flair.datasets.ColumnCorpus(
        data_folder=tasks_base_path / "fashion",
        column_format=fashion_columns,
    )
    germeval_corpus = flair.datasets.NER_GERMAN_GERMEVAL(
        base_path=tasks_base_path
    ).downsample(0.1)

    combined_corpus = MultiCorpus([fashion_corpus, germeval_corpus])
    ner_dictionary = combined_corpus.make_label_dictionary("ner")

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=64,
        embeddings=turian_embeddings,
        tag_dictionary=ner_dictionary,
        tag_type="ner",
        use_crf=False,
    )

    # First phase: two epochs, writing a checkpoint along the way.
    trainer = ModelTrainer(tagger, combined_corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    # Drop the local name so the second phase uses only the reloaded model.
    del tagger

    # Second phase: reload the checkpoint and continue with max_epochs=4.
    checkpoint_file = results_base_path / "checkpoint.pt"
    resumed_tagger = SequenceTagger.load(checkpoint_file)
    trainer.resume(model=resumed_tagger, max_epochs=4)

    # Release the trainer reference; nothing on disk is removed here.
    del trainer