Example #1
    def check_save_load(self,
                        text_config,
                        input_ids,
                        attention_mask,
                        vision_config,
                        pixel_values=None,
                        **kwargs):
        """Outputs must match (within 1e-5) after a save_pretrained/from_pretrained round trip."""
        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)
        model = VisionTextDualEncoderModel(vision_model=vision_model,
                                           text_model=text_model)
        model.to(torch_device)
        model.eval()

        # reference forward pass before serialization
        with torch.no_grad():
            output = model(input_ids=input_ids,
                           pixel_values=pixel_values,
                           attention_mask=attention_mask)
            out_1 = output[0].cpu().numpy()

            # round-trip the weights through save_pretrained/from_pretrained
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = VisionTextDualEncoderModel.from_pretrained(
                    tmpdirname).eval()
                model.to(torch_device)

                after_output = model(input_ids=input_ids,
                                     pixel_values=pixel_values,
                                     attention_mask=attention_mask)
                out_2 = after_output[0].cpu().numpy()
                # the reloaded model must reproduce the original outputs
                max_diff = np.amax(np.abs(out_2 - out_1))
                self.assertLessEqual(max_diff, 1e-5)
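
Each check in these examples builds its encoders through a get_vision_text_model helper that the concrete test mixin is expected to provide. A minimal sketch of such a helper, assuming a ViT vision tower paired with a BERT text tower (that pairing is an assumption for illustration, not something the snippet above fixes):

    # assumes: from transformers import ViTModel, BertModel
    def get_vision_text_model(self, vision_config, text_config):
        # Any compatible vision/text backbone pair from transformers
        # would work the same way; ViT + BERT is just one choice.
        vision_model = ViTModel(vision_config).to(torch_device).eval()
        text_model = BertModel(text_config).to(torch_device).eval()
        return vision_model, text_model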
Example #2
    def check_vision_text_output_attention(self,
                                           text_config,
                                           input_ids,
                                           attention_mask,
                                           vision_config,
                                           pixel_values=None,
                                           **kwargs):
        """Attention outputs must have the expected count and shapes on both towers."""
        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)
        model = VisionTextDualEncoderModel(vision_model=vision_model,
                                           text_model=text_model)
        model.to(torch_device)
        model.eval()

        output = model(input_ids=input_ids,
                       pixel_values=pixel_values,
                       attention_mask=attention_mask,
                       output_attentions=True)

        # one attention tensor per transformer layer on the vision side
        vision_attentions = output.vision_model_output.attentions
        self.assertEqual(len(vision_attentions),
                         vision_config.num_hidden_layers)

        # in DeiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
        image_size = to_2tuple(vision_model.config.image_size)
        patch_size = to_2tuple(vision_model.config.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] //
                                                          patch_size[0])
        seq_len = num_patches + 2
        self.assertEqual(vision_attentions[0].shape[-3:],
                         (vision_config.num_attention_heads, seq_len, seq_len))

        # one attention tensor per layer on the text side, square over the token sequence
        text_attentions = output.text_model_output.attentions
        self.assertEqual(len(text_attentions), text_config.num_hidden_layers)

        self.assertEqual(
            text_attentions[0].shape[-3:],
            (text_config.num_attention_heads, input_ids.shape[-1],
             input_ids.shape[-1]),
        )
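
As a quick sanity check of the seq_len arithmetic above (the 224/16 sizes are illustrative, not taken from the test configs):

    # illustrative numbers only: a 224x224 image with 16x16 patches
    image_size = (224, 224)
    patch_size = (16, 16)
    num_patches = (224 // 16) * (224 // 16)  # 14 * 14 = 196
    seq_len = num_patches + 2                # 198: patches + [CLS] + distillation token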
Example #3
    def check_model_from_pretrained_configs(self,
                                            text_config,
                                            input_ids,
                                            attention_mask,
                                            vision_config,
                                            pixel_values=None,
                                            **kwargs):
        """A model built from a fused config must produce (batch, projection_dim) embeddings."""
        # fuse the two sub-configs into a single dual-encoder config
        config = VisionTextDualEncoderConfig.from_vision_text_configs(
            vision_config, text_config)

        model = VisionTextDualEncoderModel(config)
        model.to(torch_device)
        model.eval()

        output = model(input_ids=input_ids,
                       pixel_values=pixel_values,
                       attention_mask=attention_mask)

        self.assertEqual(output["text_embeds"].shape,
                         (input_ids.shape[0], config.projection_dim))
        self.assertEqual(output["image_embeds"].shape,
                         (pixel_values.shape[0], config.projection_dim))
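
The same from_vision_text_configs pattern works outside the test suite with stock configs; a minimal sketch, assuming ViT and BERT configs and the library's default projection_dim:

    from transformers import (BertConfig, ViTConfig,
                              VisionTextDualEncoderConfig,
                              VisionTextDualEncoderModel)

    vision_config = ViTConfig()
    text_config = BertConfig()
    config = VisionTextDualEncoderConfig.from_vision_text_configs(
        vision_config, text_config)
    model = VisionTextDualEncoderModel(config)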
Example #4
    def check_vision_text_dual_encoder_model(self,
                                             text_config,
                                             input_ids,
                                             attention_mask,
                                             vision_config,
                                             pixel_values=None,
                                             **kwargs):
        """A model wrapping pre-built sub-models must produce (batch, projection_dim) embeddings."""
        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)
        model = VisionTextDualEncoderModel(vision_model=vision_model,
                                           text_model=text_model)
        model.to(torch_device)
        model.eval()

        output = model(input_ids=input_ids,
                       pixel_values=pixel_values,
                       attention_mask=attention_mask)

        self.assertEqual(output["text_embeds"].shape,
                         (input_ids.shape[0], model.config.projection_dim))
        self.assertEqual(output["image_embeds"].shape,
                         (pixel_values.shape[0], model.config.projection_dim))
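
In the Transformers test suite, check_* helpers like these are typically driven by thin test_* wrappers. A minimal sketch, where prepare_config_and_inputs is a hypothetical mixin hook returning the keyword arguments used above:

    def test_vision_text_dual_encoder_model(self):
        # prepare_config_and_inputs (hypothetical here) returns a dict with
        # text_config, vision_config, input_ids, attention_mask, pixel_values
        inputs_dict = self.prepare_config_and_inputs()
        self.check_vision_text_dual_encoder_model(**inputs_dict)

    def test_save_load(self):
        inputs_dict = self.prepare_config_and_inputs()
        self.check_save_load(**inputs_dict)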