Example #1
    def check_vision_text_output_attention(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        kwargs = {"vision_model": vision_model, "text_model": text_model}
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)

        output = model(
            input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, output_attentions=True
        )

        vision_attentions = output.vision_model_output.attentions
        self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers)

        # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
        image_size = to_2tuple(vision_model.config.image_size)
        patch_size = to_2tuple(vision_model.config.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        seq_len = num_patches + 1
        self.assertEqual(vision_attentions[0].shape[-3:], (vision_config.num_attention_heads, seq_len, seq_len))

        text_attentions = output.text_model_output.attentions
        self.assertEqual(len(text_attentions), text_config.num_hidden_layers)

        self.assertEqual(
            text_attentions[0].shape[-3:],
            (text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
        )
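`to_2tuple` normalizes a scalar image or patch size into a (height, width) pair before the patch count is computed. A minimal sketch of such a helper, assuming the usual ViT convention (the actual helper imported by the test may differ in details):

def to_2tuple(x):
    # Pass (h, w) pairs through unchanged; duplicate scalars, e.g. 224 -> (224, 224).
    if isinstance(x, (tuple, list)):
        return tuple(x)
    return (x, x)


# With a 30x30 image and 2x2 patches: (30 // 2) * (30 // 2) = 225 patches,
# so the ViT sequence length is 225 + 1 = 226 including the [CLS] token.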
Example #2
    def check_save_load(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        kwargs = {"vision_model": vision_model, "text_model": text_model}
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)

        output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        out_1 = output[0]

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname)

            after_output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
            out_2 = after_output[0]
            max_diff = np.amax(np.abs(out_2 - out_1))
            self.assertLessEqual(max_diff, 1e-3)
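The same save/reload round trip can be exercised outside the test harness. A minimal sketch, assuming `transformers` with Flax support is installed and reusing the tiny checkpoints from Example #4 (checkpoint names and the tolerance are illustrative, taken from these examples rather than prescribed by the API):

import tempfile

import numpy as np
from transformers import FlaxVisionTextDualEncoderModel

model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
    "hf-internal-testing/tiny-random-clip",
    "hf-internal-testing/tiny-bert",
    vision_from_pt=True,
    text_from_pt=True,
)

# Dummy inputs with the shapes the model expects.
pixel_values = np.zeros(
    (
        1,
        model.config.vision_config.num_channels,
        model.config.vision_config.image_size,
        model.config.vision_config.image_size,
    ),
    dtype=np.float32,
)
input_ids = np.ones((1, 4), dtype=np.int64)

before = model(input_ids=input_ids, pixel_values=pixel_values)[0]

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_pretrained(tmpdir)
    reloaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdir)

after = reloaded(input_ids=input_ids, pixel_values=pixel_values)[0]
# Serialization should be lossless up to float round-off.
assert np.amax(np.abs(after - before)) <= 1e-3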
Example #3
    def check_vision_text_dual_encoder_from_pretrained(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        kwargs = {"vision_model": vision_model, "text_model": text_model}
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)

        output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)

        self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], model.config.projection_dim))
        self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], model.config.projection_dim))
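The `text_embeds` and `image_embeds` checked above are the projected representations a dual encoder compares. A minimal sketch of turning them into cosine similarities, assuming `output` comes from a call like the one in this example (the normalization code is illustrative, not part of the test):

import numpy as np

# Both embeddings live in the shared projection space of size
# model.config.projection_dim.
text_embeds = np.asarray(output["text_embeds"])
image_embeds = np.asarray(output["image_embeds"])

# L2-normalize, then take dot products: entry [i, j] is the cosine
# similarity between image i and text j.
text_embeds = text_embeds / np.linalg.norm(text_embeds, axis=-1, keepdims=True)
image_embeds = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)
similarity = image_embeds @ text_embeds.T  # shape: (num_images, num_texts)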
Example #4
    def get_pretrained_model_and_inputs(self):
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
            "hf-internal-testing/tiny-random-clip",
            "hf-internal-testing/tiny-bert",
            vision_from_pt=True,
            text_from_pt=True,
        )
        batch_size = 13
        pixel_values = floats_tensor(
            [
                batch_size,
                model.config.vision_config.num_channels,
                model.config.vision_config.image_size,
                model.config.vision_config.image_size,
            ]
        )
        input_ids = ids_tensor([batch_size, 4], model.config.text_config.vocab_size)
        attention_mask = random_attention_mask([batch_size, 4])
        inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}

        return model, inputs
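`floats_tensor`, `ids_tensor`, and `random_attention_mask` are test utilities that build random inputs of the requested shapes. A minimal numpy sketch of equivalent helpers, assuming only the shapes and dtypes this example relies on (the real utilities live in the transformers test suite and differ in details):

import numpy as np

rng = np.random.default_rng(0)


def floats_tensor(shape):
    # Random float32 values, e.g. stand-in pixel values.
    return rng.standard_normal(shape).astype(np.float32)


def ids_tensor(shape, vocab_size):
    # Random token ids in [0, vocab_size).
    return rng.integers(0, vocab_size, size=shape, dtype=np.int64)


def random_attention_mask(shape):
    # Random 0/1 mask, with the first position forced to 1 so no
    # sequence is fully masked.
    mask = rng.integers(0, 2, size=shape, dtype=np.int64)
    mask[:, 0] = 1
    return mask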