def test_translation_model_has_correct_cfg():
    """Translation task wires up the seq2seq auto-model and a TranslationConfig.

    Renamed from ``test_model_has_correct_cfg``: every test in this file shared
    that name, so only the last definition was collected by pytest and the
    earlier ones silently never ran.
    """
    model = TranslationTransformer(
        HFBackboneConfig(pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random")
    )
    assert model.hparams.downstream_model_type == "transformers.AutoModelForSeq2SeqLM"
    # `is` (not isinstance) deliberately pins the exact config class.
    assert type(model.cfg) is TranslationConfig
def test_masked_lm_model_has_correct_cfg():
    """Masked-LM task wires up the masked-LM auto-model.

    Renamed from ``test_model_has_correct_cfg``: every test in this file shared
    that name, so only the last definition was collected by pytest and the
    earlier ones silently never ran.
    """
    model = MaskedLanguageModelingTransformer(
        HFBackboneConfig(pretrained_model_name_or_path='prajjwal1/bert-tiny')
    )
    assert model.hparams.downstream_model_type == 'transformers.AutoModelForMaskedLM'
    # NOTE(review): the translation/summarization tests also assert the exact
    # type of model.cfg — consider adding the matching check here once the
    # task's config class is confirmed.
def test_text_classification_model_has_correct_cfg():
    """Text-classification task wires up the sequence-classification auto-model.

    Renamed from ``test_model_has_correct_cfg``: every test in this file shared
    that name, so only the last definition was collected by pytest and the
    earlier ones silently never ran.
    """
    model = TextClassificationTransformer(
        HFBackboneConfig(pretrained_model_name_or_path='bert-base-cased')
    )
    assert model.hparams.downstream_model_type == 'transformers.AutoModelForSequenceClassification'
def test_question_answering_model_has_correct_cfg():
    """Question-answering task wires up the QA auto-model.

    Renamed from ``test_model_has_correct_cfg``: every test in this file shared
    that name, so only the last definition was collected by pytest and the
    earlier ones silently never ran.
    """
    model = QuestionAnsweringTransformer(
        HFBackboneConfig(pretrained_model_name_or_path='bert-base-cased')
    )
    assert model.hparams.downstream_model_type == 'transformers.AutoModelForQuestionAnswering'
def test_summarization_model_has_correct_cfg():
    """Summarization task wires up the seq2seq auto-model and a SummarizationConfig.

    Renamed from ``test_model_has_correct_cfg``: every test in this file shared
    that name, so only the last definition was collected by pytest and the
    earlier ones silently never ran.
    """
    model = SummarizationTransformer(
        HFBackboneConfig(pretrained_model_name_or_path='t5-base')
    )
    assert model.hparams.downstream_model_type == 'transformers.AutoModelForSeq2SeqLM'
    # `is` (not isinstance) deliberately pins the exact config class.
    assert type(model.cfg) is SummarizationConfig
def test_causal_lm_model_has_correct_cfg():
    """Causal-LM task wires up the causal-LM auto-model.

    Renamed from ``test_model_has_correct_cfg``: every test in this file shared
    that name, so only the last definition was collected by pytest and the
    earlier ones silently never ran.
    """
    model = LanguageModelingTransformer(
        HFBackboneConfig(pretrained_model_name_or_path="sshleifer/tiny-gpt2")
    )
    assert model.hparams.downstream_model_type == "transformers.AutoModelForCausalLM"