Example #1
    def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):

        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        pt_model = VisionTextDualEncoderModel(config)
        fx_model = FlaxVisionTextDualEncoderModel(config)

        fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
        fx_model.params = fx_state

        self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
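A minimal sketch, not taken from the original test suite, of how the helper above might be driven: build small vision/text configs plus matching dummy inputs and pass them to check_equivalence_pt_to_flax. The config classes (CLIPVisionConfig, BertConfig), the tiny config values, and the dummy-input helpers used here are illustrative assumptions.

    def test_pt_flax_equivalence(self):
        # illustrative tiny configs; any vision/text pair supported by the dual encoder would do
        vision_config = CLIPVisionConfig(hidden_size=32, num_hidden_layers=2,
                                         num_attention_heads=2, intermediate_size=37,
                                         image_size=30, patch_size=15)
        text_config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=2,
                                 num_attention_heads=2, intermediate_size=37)

        inputs_dict = {
            "pixel_values": floats_tensor([2, 3, 30, 30]),
            "input_ids": ids_tensor([2, 4], text_config.vocab_size),
            "attention_mask": random_attention_mask([2, 4]),
        }

        self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)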
Example #2
    def test_inference(self):
        model = VisionTextDualEncoderModel.from_pretrained(
            "clip-italian/clip-italian", logit_scale_init_value=1)
        processor = VisionTextDualEncoderProcessor.from_pretrained(
            "clip-italian/clip-italian")

        image = Image.open(
            "./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = processor(
            text=["una foto di un gatto", "una foto di un cane"],
            images=image,
            padding=True,
            return_tensors="pt")

        outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            (inputs.pixel_values.shape[0], inputs.input_ids.shape[0]))
        self.assertEqual(
            outputs.logits_per_text.shape,
            (inputs.input_ids.shape[0], inputs.pixel_values.shape[0]),
        )

        expected_logits = torch.tensor([[1.2284727, 0.3104122]])

        self.assertTrue(
            torch.allclose(outputs.logits_per_image,
                           expected_logits,
                           atol=1e-3))
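The logits above can be turned into per-caption probabilities with a softmax over the text dimension; this is a usage sketch, not part of the original test, and the values shown are only approximate, computed from the expected logits checked above.

        # probability of each caption for the single input image
        probs = outputs.logits_per_image.softmax(dim=1)
        # roughly tensor([[0.71, 0.29]]) given expected_logits of [[1.2285, 0.3104]]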
Example #3
    def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids,
                                  attention_mask, pixel_values, **kwargs):

        pt_model.to(torch_device)
        pt_model.eval()

        # prepare inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values
        }
        pt_inputs = inputs_dict
        flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs).to_tuple()

        fx_outputs = fx_model(**flax_inputs).to_tuple()
        self.assertEqual(len(fx_outputs), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
            self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

        # PT -> Flax
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(
                tmpdirname, from_pt=True)

        fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
        self.assertEqual(len(fx_outputs_loaded), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4],
                                               pt_outputs[:4]):
            self.assert_almost_equals(fx_output_loaded, pt_output.numpy(),
                                      4e-2)

        # Flax -> PT
        with tempfile.TemporaryDirectory() as tmpdirname:
            fx_model.save_pretrained(tmpdirname)
            pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(
                tmpdirname, from_flax=True)

        pt_model_loaded.to(torch_device)
        pt_model_loaded.eval()

        with torch.no_grad():
            pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs_loaded),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output_loaded in zip(fx_outputs[:4],
                                               pt_outputs_loaded[:4]):
            self.assert_almost_equals(fx_output, pt_output_loaded.numpy(),
                                      4e-2)
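The equivalence checks above rely on an assert_almost_equals helper that is not shown in these examples; a minimal numpy-based sketch of such a helper could look like the following (the helper actually used by the test mixin may differ).

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        # maximum absolute element-wise difference between the two arrays
        diff = np.abs(a - b).max()
        self.assertLessEqual(diff, tol, f"Difference between outputs is {diff} (>= {tol}).")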
Example #4
    def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):

        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        pt_model = VisionTextDualEncoderModel(config)
        fx_model = FlaxVisionTextDualEncoderModel(config)

        pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

        self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
Example #5
    def test_real_model_save_load_from_pretrained(self):
        model_2, inputs = self.get_pretrained_model_and_inputs()
        model_2.to(torch_device)

        with torch.no_grad():
            outputs = model_2(**inputs)
            out_2 = outputs[0].cpu().numpy()

            with tempfile.TemporaryDirectory() as tmp_dirname:
                model_2.save_pretrained(tmp_dirname)
                model_1 = VisionTextDualEncoderModel.from_pretrained(
                    tmp_dirname)
                model_1.to(torch_device)

                after_outputs = model_1(**inputs)
                out_1 = after_outputs[0].cpu().numpy()
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)
Example #6
    def get_pretrained_model_and_inputs(self):
        model = VisionTextDualEncoderModel.from_vision_text_pretrained(
            "hf-internal-testing/tiny-random-clip",
            "hf-internal-testing/tiny-bert")
        batch_size = 13
        pixel_values = floats_tensor([
            batch_size,
            model.vision_model.config.num_channels,
            model.vision_model.config.image_size,
            model.vision_model.config.image_size,
        ])
        input_ids = ids_tensor([batch_size, 4],
                               model.text_model.config.vocab_size)
        attention_mask = random_attention_mask([batch_size, 4])
        inputs = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

        return model, inputs
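The inputs above come from the floats_tensor, ids_tensor and random_attention_mask test utilities. Simplified sketches of what they produce are shown below; the real helpers in the transformers test utilities take extra arguments (RNG seeding, device placement) that are omitted here.

    def floats_tensor(shape, scale=1.0):
        # random float tensor of the given shape, values in [0, scale)
        return torch.rand(shape) * scale

    def ids_tensor(shape, vocab_size):
        # random token ids in [0, vocab_size)
        return torch.randint(low=0, high=vocab_size, size=tuple(shape), dtype=torch.long)

    def random_attention_mask(shape):
        # random 0/1 mask; force the last position to 1 so no row is fully masked
        mask = torch.randint(low=0, high=2, size=tuple(shape), dtype=torch.long)
        mask[:, -1] = 1
        return mask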
Example #7
    def check_vision_text_dual_encoder_from_pretrained(self,
                                                       text_config,
                                                       input_ids,
                                                       attention_mask,
                                                       vision_config,
                                                       pixel_values=None,
                                                       **kwargs):

        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)
        kwargs = {"vision_model": vision_model, "text_model": text_model}
        model = VisionTextDualEncoderModel.from_vision_text_pretrained(
            **kwargs)
        model.to(torch_device)
        model.eval()

        output = model(input_ids=input_ids,
                       pixel_values=pixel_values,
                       attention_mask=attention_mask)

        self.assertEqual(output["text_embeds"].shape,
                         (input_ids.shape[0], model.config.projection_dim))
        self.assertEqual(output["image_embeds"].shape,
                         (pixel_values.shape[0], model.config.projection_dim))
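get_vision_text_model is provided by the concrete test subclass; a sketch of what it might return for a ViT/BERT pairing is shown below (the actual subclasses may pair other backbones, such as CLIP's vision tower or DeiT with RoBERTa).

    def get_vision_text_model(self, vision_config, text_config):
        vision_model = ViTModel(vision_config).eval()
        text_model = BertModel(text_config).eval()
        return vision_model, text_model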
Example #8
    def check_vision_text_output_attention(self,
                                           text_config,
                                           input_ids,
                                           attention_mask,
                                           vision_config,
                                           pixel_values=None,
                                           **kwargs):
        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)
        model = VisionTextDualEncoderModel(vision_model=vision_model,
                                           text_model=text_model)
        model.to(torch_device)
        model.eval()

        output = model(input_ids=input_ids,
                       pixel_values=pixel_values,
                       attention_mask=attention_mask,
                       output_attentions=True)

        vision_attentions = output.vision_model_output.attentions
        self.assertEqual(len(vision_attentions),
                         vision_config.num_hidden_layers)

        # in DEiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
        image_size = to_2tuple(vision_model.config.image_size)
        patch_size = to_2tuple(vision_model.config.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] //
                                                          patch_size[0])
        seq_len = num_patches + 2
        self.assertEqual(vision_attentions[0].shape[-3:],
                         (vision_config.num_attention_heads, seq_len, seq_len))

        text_attentions = output.text_model_output.attentions
        self.assertEqual(len(text_attentions), text_config.num_hidden_layers)

        self.assertEqual(
            text_attentions[0].shape[-3:],
            (text_config.num_attention_heads, input_ids.shape[-1],
             input_ids.shape[-1]),
        )
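As a worked example of the patch arithmetic above (the sizes are illustrative, not taken from the test configs):

        image_size = to_2tuple(224)              # (224, 224)
        patch_size = to_2tuple(16)               # (16, 16)
        num_patches = (224 // 16) * (224 // 16)  # 196 patches
        seq_len = num_patches + 2                # 198 once the [CLS] and distillation tokens are added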
Example #9
    def check_model_from_pretrained_configs(self,
                                            text_config,
                                            input_ids,
                                            attention_mask,
                                            vision_config,
                                            pixel_values=None,
                                            **kwargs):
        config = VisionTextDualEncoderConfig.from_vision_text_configs(
            vision_config, text_config)

        model = VisionTextDualEncoderModel(config)
        model.to(torch_device)
        model.eval()

        output = model(input_ids=input_ids,
                       pixel_values=pixel_values,
                       attention_mask=attention_mask)

        self.assertEqual(output["text_embeds"].shape,
                         (input_ids.shape[0], config.projection_dim))
        self.assertEqual(output["image_embeds"].shape,
                         (pixel_values.shape[0], config.projection_dim))
Example #10
    def check_save_load(self,
                        text_config,
                        input_ids,
                        attention_mask,
                        vision_config,
                        pixel_values=None,
                        **kwargs):
        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)
        model = VisionTextDualEncoderModel(vision_model=vision_model,
                                           text_model=text_model)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(input_ids=input_ids,
                           pixel_values=pixel_values,
                           attention_mask=attention_mask)
            out_1 = output[0].cpu().numpy()

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = VisionTextDualEncoderModel.from_pretrained(
                    tmpdirname).eval()
                model.to(torch_device)

                after_output = model(input_ids=input_ids,
                                     pixel_values=pixel_values,
                                     attention_mask=attention_mask)
                out_2 = after_output[0].cpu().numpy()
                max_diff = np.amax(np.abs(out_2 - out_1))
                self.assertLessEqual(max_diff, 1e-5)