def check_vision_text_output_attention(
    self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
):
    """Check the number and shape of the attention maps returned by both towers.

    Builds a dual-encoder from the vision/text models returned by
    ``self.get_vision_text_model``, runs it with ``output_attentions=True``
    and asserts, for each tower, that there is one attention entry per
    hidden layer and that the last three dimensions of the first entry are
    ``(num_attention_heads, seq_len, seq_len)``.

    Any extra ``**kwargs`` passed by the caller are accepted but not used.
    """
    vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
    model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
        vision_model=vision_model, text_model=text_model
    )

    output = model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        output_attentions=True,
    )

    # --- vision tower ---
    vision_attentions = output.vision_model_output.attentions
    self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers)

    # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
    image_size = to_2tuple(vision_model.config.image_size)
    patch_size = to_2tuple(vision_model.config.patch_size)
    patches_along_dim0 = image_size[0] // patch_size[0]
    patches_along_dim1 = image_size[1] // patch_size[1]
    vision_seq_len = patches_along_dim1 * patches_along_dim0 + 1
    expected_vision_shape = (vision_config.num_attention_heads, vision_seq_len, vision_seq_len)
    self.assertEqual(vision_attentions[0].shape[-3:], expected_vision_shape)

    # --- text tower ---
    text_attentions = output.text_model_output.attentions
    self.assertEqual(len(text_attentions), text_config.num_hidden_layers)

    # The text sequence length is read from the last axis of ``input_ids``.
    self.assertEqual(
        text_attentions[0].shape[-3:],
        (text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
    )
def check_save_load(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
    """Check that a save/reload round trip leaves the first model output unchanged.

    Runs the model once, saves it into a temporary directory, reloads it with
    ``from_pretrained`` and runs it again on the same inputs. The largest
    absolute element-wise difference between the first output of each run
    must not exceed ``1e-3``.

    Any extra ``**kwargs`` passed by the caller are accepted but not used.
    """
    vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
    model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
        vision_model=vision_model, text_model=text_model
    )

    # Both forward passes use exactly the same inputs.
    forward_inputs = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "attention_mask": attention_mask,
    }
    reference = model(**forward_inputs)[0]

    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)
        reloaded_model = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname)

        # Keep the comparison inside the context so the saved files still
        # exist while the reloaded model is being used.
        reloaded = reloaded_model(**forward_inputs)[0]
        largest_deviation = np.amax(np.abs(reloaded - reference))
        self.assertLessEqual(largest_deviation, 1e-3)
def check_vision_text_dual_encoder_from_pretrained(
    self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
):
    """Check the embedding shapes of a model built via ``from_vision_text_pretrained``.

    Asserts that ``text_embeds`` has shape ``(input_ids.shape[0], projection_dim)``
    and ``image_embeds`` has shape ``(pixel_values.shape[0], projection_dim)``,
    where ``projection_dim`` is read from the combined model's config.

    ``pixel_values`` defaults to ``None`` in the signature, but its ``.shape``
    is read here, so callers need to supply it. Any extra ``**kwargs`` passed
    by the caller are accepted but not used.
    """
    vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
    model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
        vision_model=vision_model, text_model=text_model
    )

    output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)

    text_embeds_shape = output["text_embeds"].shape
    self.assertEqual(text_embeds_shape, (input_ids.shape[0], model.config.projection_dim))

    image_embeds_shape = output["image_embeds"].shape
    self.assertEqual(image_embeds_shape, (pixel_values.shape[0], model.config.projection_dim))
def get_pretrained_model_and_inputs(self):
    """Return a dual-encoder built from two tiny checkpoints plus matching inputs.

    The model combines ``hf-internal-testing/tiny-random-clip`` (vision) and
    ``hf-internal-testing/tiny-bert`` (text), both requested with the
    ``*_from_pt=True`` flags. Inputs are generated for a batch of 13 with a
    text length of 4; image dimensions and vocabulary size are taken from
    the loaded model's config.

    Returns:
        A ``(model, inputs)`` tuple where ``inputs`` is a dict with the keys
        ``pixel_values``, ``input_ids`` and ``attention_mask``.
    """
    model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
        "hf-internal-testing/tiny-random-clip",
        "hf-internal-testing/tiny-bert",
        vision_from_pt=True,
        text_from_pt=True,
    )

    batch_size = 13
    text_length = 4
    vision_cfg = model.config.vision_config
    text_cfg = model.config.text_config

    # Generate the tensors in a fixed order (pixels, ids, mask) in case the
    # helper generators share random state.
    image_shape = [batch_size, vision_cfg.num_channels, vision_cfg.image_size, vision_cfg.image_size]
    pixel_values = floats_tensor(image_shape)
    input_ids = ids_tensor([batch_size, text_length], text_cfg.vocab_size)
    attention_mask = random_attention_mask([batch_size, text_length])

    inputs = {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
    return model, inputs