def create_and_check_model(self, config, pixel_values, labels):
    """Run a freshly built ``DeiTModel`` on *pixel_values* and verify its output shape.

    The model is constructed from *config*, moved to ``torch_device`` and put in
    eval mode before the forward pass. The resulting ``last_hidden_state`` must
    have shape ``(self.batch_size, self.seq_length, self.hidden_size)``; the
    comparison is reported through ``self.parent.assertEqual``.

    ``labels`` is not used by this check.
    """
    deit = DeiTModel(config=config)
    deit.to(torch_device)
    deit.eval()

    outputs = deit(pixel_values)

    expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
    self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)
def create_and_check_model(self, config, pixel_values, labels):
    """Run a freshly built ``DeiTModel`` and check its sequence length against the patch grid.

    The model is constructed from *config*, moved to ``torch_device``, switched
    to eval mode and applied to *pixel_values*. The expected
    ``last_hidden_state`` shape is ``(self.batch_size, num_patches + 2,
    self.hidden_size)``. Here ``num_patches`` is the number of whole patches of
    size ``self.patch_size`` that fit in ``self.image_size`` along each of the
    two dimensions. The comparison is reported through
    ``self.parent.assertEqual``.

    ``labels`` is not used by this check.
    """
    deit = DeiTModel(config=config)
    deit.to(torch_device)
    deit.eval()
    outputs = deit(pixel_values)

    # Sequence = one token per patch, plus 2 extra tokens: [CLS] and distillation.
    image_dims = to_2tuple(self.image_size)
    patch_dims = to_2tuple(self.patch_size)
    patches_axis1 = image_dims[1] // patch_dims[1]
    patches_axis0 = image_dims[0] // patch_dims[0]
    expected_seq_len = patches_axis1 * patches_axis0 + 2

    expected_shape = (self.batch_size, expected_seq_len, self.hidden_size)
    self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)
def test_model_from_pretrained(self):
    """Check that the first checkpoint in ``DEIT_PRETRAINED_MODEL_ARCHIVE_LIST`` loads.

    Only the first entry is tried; the ``[:1]`` slice means an empty list
    simply results in no load attempts.
    """
    checkpoints_to_try = DEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
    for checkpoint in checkpoints_to_try:
        loaded = DeiTModel.from_pretrained(checkpoint)
        self.assertIsNotNone(loaded)
def get_vision_text_model(self, vision_config, text_config):
    """Build the vision and text models used by the dual-model test.

    Args:
        vision_config: configuration passed to ``DeiTModel``.
        text_config: configuration passed to ``RobertaModel``.

    Returns:
        ``(vision_model, text_model)``, each being the value returned by
        calling ``.eval()`` on the freshly constructed model.
    """
    return DeiTModel(vision_config).eval(), RobertaModel(text_config).eval()
def get_encoder_decoder_model(self, config, decoder_config):
    """Build the encoder and decoder models used by the encoder-decoder test.

    Args:
        config: configuration passed to ``DeiTModel`` (the encoder).
        decoder_config: configuration passed to ``BertLMHeadModel`` (the decoder).

    Returns:
        ``(encoder_model, decoder_model)``, each being the value returned by
        calling ``.eval()`` on the freshly constructed model.
    """
    return DeiTModel(config).eval(), BertLMHeadModel(decoder_config).eval()