def test_get_image_features(self):
        """Check ``get_image_features`` matches with and without JIT.

        The same ``jax.jit``-decorated function is run once compiled and once
        inside ``jax.disable_jit()``. The two outputs must have equal shapes
        and agree element-wise within ``atol=1e-3`` / ``rtol=1e-5``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxCLIPModel(config)

        @jax.jit
        def model_jitted(pixel_values):
            return model.get_image_features(pixel_values=pixel_values)

        with self.subTest("JIT Enabled"):
            jitted_output = model_jitted(inputs_dict["pixel_values"])

        with self.subTest("JIT Disabled"):
            # Under disable_jit the decorated function runs without
            # compilation, providing a reference for the compiled result.
            with jax.disable_jit():
                output = model_jitted(inputs_dict["pixel_values"])

        self.assertEqual(jitted_output.shape, output.shape)
        # Same tolerance as np.allclose(..., atol=1e-3) (whose rtol default is
        # 1e-5 and which treats NaNs as unequal), but on failure this reports
        # the mismatching elements instead of "False is not true".
        np.testing.assert_allclose(
            jitted_output, output, rtol=1e-5, atol=1e-3, equal_nan=False
        )
    def test_get_text_features(self):
        """Check ``get_text_features`` matches with and without JIT.

        The same ``jax.jit``-decorated function is run once compiled and once
        inside ``jax.disable_jit()``. The two outputs must have equal shapes
        and agree element-wise within ``atol=1e-3`` / ``rtol=1e-5``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxCLIPModel(config)

        @jax.jit
        def model_jitted(input_ids, attention_mask, **kwargs):
            # ``**kwargs`` absorbs the other entries of ``inputs_dict`` (such
            # as ``pixel_values``) so the whole dict can be splatted in.
            return model.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )

        with self.subTest("JIT Enabled"):
            jitted_output = model_jitted(**inputs_dict)

        with self.subTest("JIT Disabled"):
            # Under disable_jit the decorated function runs without
            # compilation, providing a reference for the compiled result.
            with jax.disable_jit():
                output = model_jitted(**inputs_dict)

        self.assertEqual(jitted_output.shape, output.shape)
        # Same tolerance as np.allclose(..., atol=1e-3) (whose rtol default is
        # 1e-5 and which treats NaNs as unequal), but on failure this reports
        # the mismatching elements instead of "False is not true".
        np.testing.assert_allclose(
            jitted_output, output, rtol=1e-5, atol=1e-3, equal_nan=False
        )