Exemplo n.º 1
0
def test_text_classifier_transformer_finetune(results_base_path,
                                              tasks_base_path):
    """Fine-tune a DistilBERT document classifier on a trivial corpus.

    Trains a single-label ``TextClassifier`` on the ``city`` label of the
    ``trivial_text_classification_single`` task. It then checks that:

    * the in-memory model can predict on a normal and a whitespace-only
      sentence, singly and batched;
    * the model reloaded from ``final-model.pt`` predicts "Berlin" with
      well-formed labels (non-None value, float score in [0, 1]);
    * the reloaded model reaches a micro-averaged F1 of 1.0 on the test split.

    Args:
        results_base_path: directory the trainer writes to; the trained
            model is reloaded from ``final-model.pt`` inside it.
        tasks_base_path: root directory of the test task resources.
    """
    flair.set_seed(123)

    corpus = ClassificationCorpus(
        tasks_base_path / "trivial" / "trivial_text_classification_single",
        label_type="city",
    )
    label_dict = corpus.make_label_dictionary(label_type="city")

    model: TextClassifier = TextClassifier(
        document_embeddings=TransformerDocumentEmbeddings(
            "distilbert-base-uncased"),
        label_dictionary=label_dict,
        label_type="city",
        multi_label=False,
    )

    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune(
        results_base_path,
        mini_batch_size=2,
        max_epochs=10,
        shuffle=True,
        learning_rate=0.5e-5,
        num_workers=2,
    )

    # check that the in-memory model can predict, including on a
    # whitespace-only sentence, alone and in a batch
    sentence = Sentence("this is Berlin")
    sentence_empty = Sentence("       ")

    model.predict(sentence)
    model.predict([sentence, sentence_empty])
    model.predict([sentence_empty])

    # load the model that fine_tune saved to disk
    loaded_model = TextClassifier.load(results_base_path / "final-model.pt")

    # check if the loaded model predicts the correct label; use fresh
    # Sentence objects so the assertions below only see labels produced
    # by loaded_model, not those attached by the in-memory model above
    sentence = Sentence("this is Berlin")
    sentence_empty = Sentence("       ")

    loaded_model.predict([sentence, sentence_empty])

    values = []
    for label in sentence.labels:
        assert label.value is not None
        assert isinstance(label.score, float)
        assert 0.0 <= label.score <= 1.0
        values.append(label.value)

    assert "Berlin" in values

    # check if loaded model successfully fit the training data
    result: Result = loaded_model.evaluate(corpus.test, gold_label_type="city")
    assert result.classification_report["micro avg"]["f1-score"] == 1.0

    del loaded_model
Exemplo n.º 2
0
def test_sequence_tagger_transformer_finetune(results_base_path,
                                              tasks_base_path):
    """Fine-tune a DistilBERT sequence tagger on a trivial BIOES NER corpus.

    The tagger is built without a CRF, without an RNN and without embedding
    reprojection. After training, the model saved as ``final-model.pt`` is
    reloaded and must:

    * run on a normal and a whitespace-only sentence, singly and batched;
    * tag the span "New York";
    * reach a micro-averaged F1 of 1.0 on the test split.

    Args:
        results_base_path: directory the trainer writes to.
        tasks_base_path: root directory of the test task resources.
    """
    flair.set_seed(123)

    # two-column data: token text and its NER tag
    data_folder = tasks_base_path / "trivial" / "trivial_bioes"
    corpus: Corpus = ColumnCorpus(
        data_folder=data_folder,
        column_format={0: "text", 1: "ner"},
    )
    tag_dictionary = corpus.make_label_dictionary("ner")

    # transformer embeddings with fine_tune=True; the tagger itself uses
    # no CRF, no RNN and no reprojection layer
    embeddings = TransformerWordEmbeddings("distilbert-base-uncased",
                                           fine_tune=True)
    tagger: SequenceTagger = SequenceTagger(
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type="ner",
        hidden_size=64,
        use_crf=False,
        use_rnn=False,
        reproject_embeddings=False,
    )

    trainer = ModelTrainer(tagger, corpus)
    trainer.fine_tune(
        results_base_path,
        learning_rate=0.5e-4,
        mini_batch_size=2,
        max_epochs=10,
        shuffle=True,
    )

    model_path = results_base_path / "final-model.pt"
    loaded_model: SequenceTagger = SequenceTagger.load(model_path)

    sentence = Sentence("this is New York")
    sentence_empty = Sentence("       ")

    # single sentence, mixed batch, and a batch holding only the empty one
    for batch in (sentence, [sentence, sentence_empty], [sentence_empty]):
        loaded_model.predict(batch)

    # the reloaded model must recover the entity from the training data
    found_spans = {span.text for span in sentence.get_spans("ner")}
    assert "New York" in found_spans

    # the reloaded model must fit the (trivial) data perfectly
    result: Result = loaded_model.evaluate(corpus.test, gold_label_type="ner")
    assert result.classification_report["micro avg"]["f1-score"] == 1.0

    del loaded_model
Exemplo n.º 3
0
def main():
    """Fine-tune a transformer-based NER ``SequenceTagger`` with Flair.

    Arguments are parsed into ``ModelArguments``, ``TrainingArguments``,
    ``FlertArguments`` and ``DataArguments``. They come from one of two
    sources:

    * a single ``.json`` file given as the only command-line argument, or
    * regular command-line flags.

    The tagger is fine-tuned into ``data_args.output_dir``. The model and
    training argument objects are then written there with ``torch.save``,
    and the model card is printed.
    """
    parser = HfArgumentParser((ModelArguments, TrainingArguments, FlertArguments, DataArguments))

    cli_args = sys.argv[1:]
    if len(cli_args) == 1 and cli_args[0].endswith(".json"):
        parsed = parser.parse_json_file(json_file=os.path.abspath(cli_args[0]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    model_args, training_args, flert_args, data_args = parsed

    set_seed(training_args.seed)
    flair.device = training_args.device

    corpus = get_flair_corpus(data_args)
    logger.info(corpus)

    tag_type: str = "ner"
    tag_dictionary = corpus.make_label_dictionary(tag_type)
    logger.info(tag_dictionary)

    # FLERT document-context settings are taken from flert_args
    embeddings = TransformerWordEmbeddings(
        model=model_args.model_name_or_path,
        layers=model_args.layers,
        subtoken_pooling=model_args.subtoken_pooling,
        fine_tune=True,
        use_context=flert_args.context_size,
        respect_document_boundaries=flert_args.respect_document_boundaries,
    )

    tagger = SequenceTagger(
        hidden_size=model_args.hidden_size,
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type=tag_type,
        use_crf=model_args.use_crf,
        use_rnn=False,
        reproject_embeddings=False,
    )

    output_dir = data_args.output_dir
    trainer = ModelTrainer(tagger, corpus)
    trainer.fine_tune(
        output_dir,
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.batch_size,
        mini_batch_chunk_size=training_args.mini_batch_chunk_size,
        max_epochs=training_args.num_epochs,
        embeddings_storage_mode=training_args.embeddings_storage_mode,
        weight_decay=training_args.weight_decay,
    )

    # persist the argument objects next to the trained model
    for args_obj, file_name in (
        (model_args, "model_args.bin"),
        (training_args, "training_args.bin"),
    ):
        torch.save(args_obj, os.path.join(output_dir, file_name))

    # finally, print model card for information
    tagger.print_model_card()