def check_save_load(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
    """Check that a save/reload round trip leaves the model output unchanged.

    A dual encoder is assembled from the vision and text models returned by
    ``self.get_vision_text_model``, run once, written to a temporary
    directory with ``save_pretrained``, restored with ``from_pretrained`` and
    run again on the same inputs. The first element of both outputs must
    agree within an absolute tolerance of 1e-5.

    Args:
        text_config: Configuration passed to ``get_vision_text_model`` for the text tower.
        input_ids: Text inputs forwarded to the model.
        attention_mask: Attention mask forwarded to the model.
        vision_config: Configuration passed to ``get_vision_text_model`` for the vision tower.
        pixel_values: Image inputs forwarded to the model.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """
    vision_tower, text_tower = self.get_vision_text_model(vision_config, text_config)
    dual_encoder = VisionTextDualEncoderModel(vision_model=vision_tower, text_model=text_tower)
    dual_encoder.to(torch_device)
    dual_encoder.eval()

    # The same keyword arguments are fed to the model before and after the reload.
    forward_inputs = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "attention_mask": attention_mask,
    }

    with torch.no_grad():
        reference = dual_encoder(**forward_inputs)[0].cpu().numpy()

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            dual_encoder.save_pretrained(checkpoint_dir)

            # Rebind the same name so the pre-save instance is released.
            dual_encoder = VisionTextDualEncoderModel.from_pretrained(checkpoint_dir).eval()
            dual_encoder.to(torch_device)

            restored = dual_encoder(**forward_inputs)[0].cpu().numpy()
            largest_deviation = np.amax(np.abs(restored - reference))
            self.assertLessEqual(largest_deviation, 1e-5)
def check_vision_text_output_attention(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
    """Check the number and shape of the attention maps returned by both towers.

    A dual encoder is assembled from the vision and text models returned by
    ``self.get_vision_text_model`` and run with ``output_attentions=True``.
    For each tower the test asserts one attention tensor per hidden layer and
    that the trailing three dimensions of the first tensor equal
    ``(num_attention_heads, seq_len, seq_len)``.

    Args:
        text_config: Text configuration; supplies ``num_hidden_layers`` and ``num_attention_heads``.
        input_ids: Text inputs; the last dimension is used as the text sequence length.
        attention_mask: Attention mask forwarded to the model.
        vision_config: Vision configuration; supplies ``num_hidden_layers`` and ``num_attention_heads``.
        pixel_values: Image inputs forwarded to the model.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """
    vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
    model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model)
    model.to(torch_device)
    model.eval()

    # Only shapes are inspected, so no autograd graph is needed.
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            output_attentions=True,
        )

    vision_attentions = output.vision_model_output.attentions
    self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers)

    # in DeiT, the seq_len equals the number of patches + 2
    # (we add 2 for the [CLS] and distillation tokens)
    image_size = to_2tuple(vision_model.config.image_size)
    patch_size = to_2tuple(vision_model.config.patch_size)
    num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
    seq_len = num_patches + 2
    self.assertEqual(
        vision_attentions[0].shape[-3:],
        (vision_config.num_attention_heads, seq_len, seq_len),
    )

    text_attentions = output.text_model_output.attentions
    self.assertEqual(len(text_attentions), text_config.num_hidden_layers)

    text_seq_len = input_ids.shape[-1]
    self.assertEqual(
        text_attentions[0].shape[-3:],
        (text_config.num_attention_heads, text_seq_len, text_seq_len),
    )
def check_model_from_pretrained_configs(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
    """Check embedding shapes for a model built from a combined config.

    The vision and text configs are merged with
    ``VisionTextDualEncoderConfig.from_vision_text_configs`` and the model is
    instantiated from that config. After one forward pass the text and image
    embeddings must each have shape ``(batch_size, config.projection_dim)``.

    Args:
        text_config: Configuration for the text tower.
        input_ids: Text inputs; the first dimension is the expected text batch size.
        attention_mask: Attention mask forwarded to the model.
        vision_config: Configuration for the vision tower.
        pixel_values: Image inputs; the first dimension is the expected image batch size.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """
    config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

    model = VisionTextDualEncoderModel(config)
    model.to(torch_device)
    model.eval()

    # Only output shapes are checked, so no autograd graph is needed.
    with torch.no_grad():
        output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)

    self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], config.projection_dim))
    self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], config.projection_dim))
def check_vision_text_dual_encoder_model(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
    """Check embedding shapes for a model built from separate vision and text models.

    A dual encoder is assembled from the models returned by
    ``self.get_vision_text_model``. After one forward pass the text and image
    embeddings must each have shape ``(batch_size, model.config.projection_dim)``.

    Args:
        text_config: Configuration passed to ``get_vision_text_model`` for the text tower.
        input_ids: Text inputs; the first dimension is the expected text batch size.
        attention_mask: Attention mask forwarded to the model.
        vision_config: Configuration passed to ``get_vision_text_model`` for the vision tower.
        pixel_values: Image inputs; the first dimension is the expected image batch size.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """
    vision_model, text_model = self.get_vision_text_model(vision_config, text_config)

    model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model)
    model.to(torch_device)
    model.eval()

    # Only output shapes are checked, so no autograd graph is needed.
    with torch.no_grad():
        output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)

    projection_dim = model.config.projection_dim
    self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], projection_dim))
    self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], projection_dim))