Пример #1
0
    def get_pretrained_model_and_inputs(self):
        """Load a small pretrained dual encoder together with a random input batch.

        The model pairs the ``hf-internal-testing/tiny-random-clip`` vision
        checkpoint with the ``hf-internal-testing/tiny-bert`` text checkpoint.
        Random pixel values, token ids and an attention mask are then generated
        with shapes derived from the loaded sub-model configs.

        Returns:
            A ``(model, inputs)`` tuple. ``inputs`` is a dict with the keys
            ``"pixel_values"``, ``"input_ids"`` and ``"attention_mask"`` and can
            be splatted straight into ``model(**inputs)``.
        """
        model = VisionTextDualEncoderModel.from_vision_text_pretrained(
            "hf-internal-testing/tiny-random-clip",
            "hf-internal-testing/tiny-bert",
        )
        vision_cfg = model.vision_model.config
        text_cfg = model.text_model.config

        batch_size = 13
        # Token length shared by input_ids and attention_mask.
        seq_length = 4

        # Generation order (pixels -> ids -> mask) is kept fixed so the random
        # draws are reproducible across runs.
        pixels = floats_tensor([
            batch_size,
            vision_cfg.num_channels,
            vision_cfg.image_size,
            vision_cfg.image_size,
        ])
        token_ids = ids_tensor([batch_size, seq_length], text_cfg.vocab_size)
        mask = random_attention_mask([batch_size, seq_length])

        return model, dict(
            pixel_values=pixels,
            input_ids=token_ids,
            attention_mask=mask,
        )
Пример #2
0
    def check_vision_text_dual_encoder_from_pretrained(
            self,
            text_config,
            input_ids,
            attention_mask,
            vision_config,
            pixel_values=None,
            **kwargs):
        """Build a dual encoder from existing sub-models and check embedding shapes.

        The vision and text sub-models come from ``self.get_vision_text_model``.
        They are combined with ``from_vision_text_pretrained``, the result is
        moved to ``torch_device`` and put into eval mode, and a single forward
        pass is run.

        The test asserts that both ``text_embeds`` and ``image_embeds`` have
        shape ``(batch, model.config.projection_dim)``. The batch size is taken
        from ``input_ids`` and ``pixel_values`` respectively.

        Any extra ``**kwargs`` are accepted for signature compatibility but are
        not used. ``pixel_values`` must be supplied in practice: the image-shape
        assertion reads ``pixel_values.shape``.
        """
        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)

        model = VisionTextDualEncoderModel.from_vision_text_pretrained(
            vision_model=vision_model,
            text_model=text_model,
        )
        model.to(torch_device)
        model.eval()

        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
        )

        # Text embeddings are checked first; pixel_values is only touched after.
        self.assertEqual(
            outputs["text_embeds"].shape,
            (input_ids.shape[0], model.config.projection_dim),
        )
        self.assertEqual(
            outputs["image_embeds"].shape,
            (pixel_values.shape[0], model.config.projection_dim),
        )