Example #1
0
    def check_save_load(self,
                        text_config,
                        input_ids,
                        attention_mask,
                        vision_config,
                        pixel_values=None,
                        **kwargs):
        """Check that model outputs are (nearly) unchanged after a
        save_pretrained / from_pretrained round trip."""
        vision_model, text_model = self.get_vision_text_model(
            vision_config, text_config)
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
            vision_model=vision_model, text_model=text_model)

        out_before = model(input_ids=input_ids,
                           pixel_values=pixel_values,
                           attention_mask=attention_mask)[0]

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            reloaded = FlaxVisionTextDualEncoderModel.from_pretrained(
                tmpdirname)

            out_after = reloaded(input_ids=input_ids,
                                 pixel_values=pixel_values,
                                 attention_mask=attention_mask)[0]
            # Serialization may lose a little numerical precision; allow 1e-3.
            max_diff = np.amax(np.abs(out_after - out_before))
            self.assertLessEqual(max_diff, 1e-3)
Example #2
0
    def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
        """Port PyTorch weights into the Flax model, then check output parity."""
        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        pt_model = VisionTextDualEncoderModel(config)
        fx_model = FlaxVisionTextDualEncoderModel(config)

        # Convert the freshly initialized PyTorch state dict into Flax params
        # so both models share identical weights.
        fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
Example #3
0
    def check_vision_text_output_attention(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        """Check that attention tensors from both towers have the expected
        per-layer count and (heads, seq_len, seq_len) trailing shape."""
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
            vision_model=vision_model, text_model=text_model
        )

        output = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            output_attentions=True,
        )

        vision_attn = output.vision_model_output.attentions
        self.assertEqual(len(vision_attn), vision_config.num_hidden_layers)

        # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
        img_hw = to_2tuple(vision_model.config.image_size)
        patch_hw = to_2tuple(vision_model.config.patch_size)
        num_patches = (img_hw[1] // patch_hw[1]) * (img_hw[0] // patch_hw[0])
        vision_seq_len = num_patches + 1
        self.assertEqual(
            vision_attn[0].shape[-3:],
            (vision_config.num_attention_heads, vision_seq_len, vision_seq_len),
        )

        text_attn = output.text_model_output.attentions
        self.assertEqual(len(text_attn), text_config.num_hidden_layers)

        text_seq_len = input_ids.shape[-1]
        self.assertEqual(
            text_attn[0].shape[-3:],
            (text_config.num_attention_heads, text_seq_len, text_seq_len),
        )
Example #4
0
    def test_inference(self):
        """Integration test: run the public clip-italian checkpoint on a COCO
        fixture image and compare image-text logits against reference values."""
        model = FlaxVisionTextDualEncoderModel.from_pretrained(
            "clip-italian/clip-italian", logit_scale_init_value=1)
        processor = VisionTextDualEncoderProcessor.from_pretrained(
            "clip-italian/clip-italian")

        image = Image.open(
            "./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = processor(
            text=["una foto di un gatto", "una foto di un cane"],
            images=image,
            padding=True,
            return_tensors="np")

        outputs = model(**inputs)

        num_images = inputs.pixel_values.shape[0]
        num_texts = inputs.input_ids.shape[0]

        # verify the logits
        self.assertEqual(outputs.logits_per_image.shape, (num_images, num_texts))
        self.assertEqual(outputs.logits_per_text.shape, (num_texts, num_images))

        # Reference logits recorded from a known-good run of this checkpoint.
        expected_logits = np.array([[1.2284727, 0.3104122]])
        self.assertTrue(
            np.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
Example #5
0
    def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids,
                                  attention_mask, pixel_values, **kwargs):
        """Assert PyTorch and Flax models produce near-identical outputs,
        both directly and after cross-framework save/load round trips
        (PT -> Flax via ``from_pt=True`` and Flax -> PT via ``from_flax=True``).

        NOTE(review): sibling helpers in this file call
        ``check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)`` with a
        single dict as the third positional argument, which would bind the
        whole dict to ``input_ids`` under this signature — confirm which
        variant the callers expect.
        """

        pt_model.to(torch_device)
        pt_model.eval()

        # prepare inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values
        }
        pt_inputs = inputs_dict
        # NOTE(review): ``.numpy()`` only works on CPU tensors — presumably the
        # inputs live on CPU even when torch_device is a GPU; verify.
        flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs).to_tuple()

        fx_outputs = fx_model(**flax_inputs).to_tuple()
        self.assertEqual(len(fx_outputs), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        # Only the first 4 tuple entries are compared numerically; trailing
        # entries are intentionally skipped.
        for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
            self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

        # PT -> Flax
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(
                tmpdirname, from_pt=True)

        fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
        self.assertEqual(len(fx_outputs_loaded), len(pt_outputs),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4],
                                               pt_outputs[:4]):
            self.assert_almost_equals(fx_output_loaded, pt_output.numpy(),
                                      4e-2)

        # Flax -> PT
        with tempfile.TemporaryDirectory() as tmpdirname:
            fx_model.save_pretrained(tmpdirname)
            pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(
                tmpdirname, from_flax=True)

        pt_model_loaded.to(torch_device)
        pt_model_loaded.eval()

        with torch.no_grad():
            pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs_loaded),
                         "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output_loaded in zip(fx_outputs[:4],
                                               pt_outputs_loaded[:4]):
            self.assert_almost_equals(fx_output, pt_output_loaded.numpy(),
                                      4e-2)
Example #6
0
    def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
        """Port Flax weights into the PyTorch model, then check output parity."""
        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        pt_model = VisionTextDualEncoderModel(config)
        fx_model = FlaxVisionTextDualEncoderModel(config)

        # Overwrite the PyTorch weights with the Flax parameters so both
        # models compute with identical weights.
        pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
Example #7
0
    def check_model_from_pretrained_configs(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        """Build the dual encoder from a combined config and verify that the
        projected text/image embeddings have shape (batch, projection_dim)."""
        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
        model = FlaxVisionTextDualEncoderModel(config)

        output = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
        )

        proj_dim = config.projection_dim
        self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], proj_dim))
        self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], proj_dim))
Example #8
0
    def check_vision_text_dual_encoder_from_pretrained(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        """Build the dual encoder from two sub-models and verify that the
        projected text/image embeddings have shape (batch, projection_dim)."""
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
            vision_model=vision_model, text_model=text_model
        )

        output = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
        )

        proj_dim = model.config.projection_dim
        self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], proj_dim))
        self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], proj_dim))
Example #9
0
    def test_real_model_save_load_from_pretrained(self):
        """Check a real pretrained model's outputs survive a save/load round
        trip to within 1e-5."""
        original_model, inputs = self.get_pretrained_model_and_inputs()

        reference = original_model(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmp_dirname:
            original_model.save_pretrained(tmp_dirname)
            restored_model = FlaxVisionTextDualEncoderModel.from_pretrained(tmp_dirname)

            restored = restored_model(**inputs)[0]
            max_diff = np.amax(np.abs(restored - reference))
            self.assertLessEqual(max_diff, 1e-5)
Example #10
0
    def get_pretrained_model_and_inputs(self):
        """Return a tiny CLIP+BERT dual encoder together with a matching
        random input batch (pixel values, input ids, attention mask)."""
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
            "hf-internal-testing/tiny-random-clip",
            "hf-internal-testing/tiny-bert",
            vision_from_pt=True,
            text_from_pt=True,
        )
        batch_size = 13
        seq_len = 4
        vision_cfg = model.config.vision_config

        pixel_values = floats_tensor(
            [batch_size, vision_cfg.num_channels, vision_cfg.image_size, vision_cfg.image_size]
        )
        input_ids = ids_tensor([batch_size, seq_len], model.config.text_config.vocab_size)
        attention_mask = random_attention_mask([batch_size, seq_len])

        inputs = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return model, inputs