def test_model_from_pretrained(self):
    """Check that the ``google/vit-base-patch16-224`` checkpoint loads into ``TFViTModel``."""
    checkpoint = "google/vit-base-patch16-224"
    loaded = TFViTModel.from_pretrained(checkpoint)
    # Loading must yield a model object rather than None.
    self.assertIsNotNone(loaded)
def get_encoder_decoder_models(self):
    """Build the pretrained encoder and decoder models for the encoder-decoder tests.

    Returns:
        dict with two entries:
        ``"encoder"``: a ``TFViTModel`` loaded from
        ``google/vit-base-patch16-224-in21k`` and named ``"encoder"``.
        ``"decoder"``: a ``TFGPT2LMHeadModel`` built with
        ``self.get_decoder_config()`` and named ``"decoder"``.
    """
    vision_encoder = TFViTModel.from_pretrained(
        "google/vit-base-patch16-224-in21k", name="encoder"
    )
    # NOTE(review): "../gpt2" looks like a path relative to the current working
    # directory rather than a Hub checkpoint id. Confirm that a local checkout is
    # intended here and not the "gpt2" Hub model.
    text_decoder = TFGPT2LMHeadModel.from_pretrained(
        "../gpt2", config=self.get_decoder_config(), name="decoder"
    )
    return dict(encoder=vision_encoder, decoder=text_decoder)