def test_text_classifier_transformer_finetune(results_base_path, tasks_base_path):
    """End-to-end test: fine-tune a transformer text classifier on a trivial
    single-label corpus, check that prediction works (including on empty
    sentences), and verify the reloaded model fits the training data perfectly.

    Args:
        results_base_path: directory where the trainer writes checkpoints / final model.
        tasks_base_path: root directory containing the trivial test corpora.
    """
    flair.set_seed(123)

    corpus = ClassificationCorpus(
        tasks_base_path / "trivial" / "trivial_text_classification_single",
        label_type="city",
    )
    label_dict = corpus.make_label_dictionary(label_type="city")

    model: TextClassifier = TextClassifier(
        document_embeddings=TransformerDocumentEmbeddings("distilbert-base-uncased"),
        label_dictionary=label_dict,
        label_type="city",
        multi_label=False,
    )

    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune(
        results_base_path,
        mini_batch_size=2,
        max_epochs=10,
        shuffle=True,
        learning_rate=0.5e-5,
        num_workers=2,
    )

    # check if model can predict
    sentence = Sentence("this is Berlin")
    sentence_empty = Sentence(" ")

    model.predict(sentence)
    model.predict([sentence, sentence_empty])
    model.predict([sentence_empty])

    # load model
    loaded_model = TextClassifier.load(results_base_path / "final-model.pt")

    # check if model predicts correct label (typo "chcek" fixed)
    sentence = Sentence("this is Berlin")
    sentence_empty = Sentence(" ")

    loaded_model.predict([sentence, sentence_empty])

    values = []
    for label in sentence.labels:
        assert label.value is not None
        assert 0.0 <= label.score <= 1.0
        # isinstance is the idiomatic type check (was: type(label.score) is float)
        assert isinstance(label.score, float)
        values.append(label.value)

    assert "Berlin" in values

    # check if loaded model successfully fit the training data
    result: Result = loaded_model.evaluate(corpus.test, gold_label_type="city")
    assert result.classification_report["micro avg"]["f1-score"] == 1.0

    del loaded_model
def test_sequence_tagger_transformer_finetune(results_base_path, tasks_base_path):
    """Fine-tune a transformer-based sequence tagger on a trivial BIOES corpus,
    reload the saved model, and verify it predicts the trained entity and fits
    the training data with a perfect micro F1.

    Args:
        results_base_path: directory where the trainer writes the final model.
        tasks_base_path: root directory containing the trivial test corpora.
    """
    flair.set_seed(123)

    # read the tiny column-formatted NER dataset
    corpus: Corpus = ColumnCorpus(
        data_folder=tasks_base_path / "trivial" / "trivial_bioes",
        column_format={0: "text", 1: "ner"},
    )
    tag_dictionary = corpus.make_label_dictionary("ner")

    # plain softmax tagger: no CRF, no RNN, no embedding reprojection
    tagger: SequenceTagger = SequenceTagger(
        hidden_size=64,
        embeddings=TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True),
        tag_dictionary=tag_dictionary,
        tag_type="ner",
        use_crf=False,
        use_rnn=False,
        reproject_embeddings=False,
    )

    # fine-tune briefly on the trivial data
    ModelTrainer(tagger, corpus).fine_tune(
        results_base_path,
        mini_batch_size=2,
        max_epochs=10,
        shuffle=True,
        learning_rate=0.5e-4,
    )

    loaded_model: SequenceTagger = SequenceTagger.load(results_base_path / "final-model.pt")

    sentence = Sentence("this is New York")
    sentence_empty = Sentence(" ")

    # prediction must work for a single sentence, a batch, and empty input
    loaded_model.predict(sentence)
    loaded_model.predict([sentence, sentence_empty])
    loaded_model.predict([sentence_empty])

    # the reloaded model must recover the trained entity
    entities = [span.text for span in sentence.get_spans("ner")]
    assert "New York" in entities

    # ...and fit the training data perfectly
    result: Result = loaded_model.evaluate(corpus.test, gold_label_type="ner")
    assert result.classification_report["micro avg"]["f1-score"] == 1.0

    del loaded_model
def main():
    """Entry point: parse arguments (from a JSON file or the command line),
    fine-tune a FLERT-style transformer sequence tagger on the configured
    corpus, and persist the argument dataclasses next to the trained model."""
    parser = HfArgumentParser((ModelArguments, TrainingArguments, FlertArguments, DataArguments))

    # a single ".json" CLI argument means "load every argument group from that file"
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        parsed = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    model_args, training_args, flert_args, data_args = parsed

    set_seed(training_args.seed)
    flair.device = training_args.device

    corpus = get_flair_corpus(data_args)
    logger.info(corpus)

    tag_type: str = "ner"
    tag_dictionary = corpus.make_label_dictionary(tag_type)
    logger.info(tag_dictionary)

    # transformer word embeddings with document context (FLERT) settings
    embeddings = TransformerWordEmbeddings(
        model=model_args.model_name_or_path,
        layers=model_args.layers,
        subtoken_pooling=model_args.subtoken_pooling,
        fine_tune=True,
        use_context=flert_args.context_size,
        respect_document_boundaries=flert_args.respect_document_boundaries,
    )

    tagger = SequenceTagger(
        hidden_size=model_args.hidden_size,
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type=tag_type,
        use_crf=model_args.use_crf,
        use_rnn=False,
        reproject_embeddings=False,
    )

    # fine-tune directly; the trainer instance is not needed afterwards
    ModelTrainer(tagger, corpus).fine_tune(
        data_args.output_dir,
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.batch_size,
        mini_batch_chunk_size=training_args.mini_batch_chunk_size,
        max_epochs=training_args.num_epochs,
        embeddings_storage_mode=training_args.embeddings_storage_mode,
        weight_decay=training_args.weight_decay,
    )

    # keep the argument dataclasses next to the model for reproducibility
    torch.save(model_args, os.path.join(data_args.output_dir, "model_args.bin"))
    torch.save(training_args, os.path.join(data_args.output_dir, "training_args.bin"))

    # finally, print model card for information
    tagger.print_model_card()