    def test_load_with_mismatched_shapes(self):
        if not self.test_mismatched_shapes:
            return
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class not in get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
                continue

            with self.subTest(msg=f"Testing {model_class}"):
                with tempfile.TemporaryDirectory() as tmp_dir:
                    model = model_class(config)
                    model.save_pretrained(tmp_dir)

                    # Fails when we don't set ignore_mismatched_sizes=True
                    with self.assertRaises(ValueError):
                        new_model = FlaxAutoModelForSequenceClassification.from_pretrained(
                            tmp_dir, num_labels=42)
                    with self.assertRaises(ValueError):
                        new_model_without_prefix = FlaxAutoModel.from_pretrained(
                            tmp_dir, vocab_size=10)

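                    # With ignore_mismatched_sizes=True the load succeeds, and a warning
                    # about the mismatched shapes is emitted instead.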
                    logger = logging.get_logger("transformers.modeling_flax_utils")
                    with CaptureLogger(logger) as cl:
                        new_model = FlaxAutoModelForSequenceClassification.from_pretrained(
                            tmp_dir,
                            num_labels=42,
                            ignore_mismatched_sizes=True)
                    self.assertIn("the shapes did not match", cl.out)

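                    # The classification head should now match the overridden label count.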
                    logits = new_model(**inputs_dict)["logits"]
                    self.assertEqual(logits.shape[1], 42)

                    with CaptureLogger(logger) as cl:
                        new_model_without_prefix = FlaxAutoModel.from_pretrained(
                            tmp_dir,
                            vocab_size=10,
                            ignore_mismatched_sizes=True)
                    self.assertIn("the shapes did not match", cl.out)
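                    # Run a forward pass with ids drawn from the new, smaller vocabulary
                    # to make sure the resized embeddings are actually usable.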
                    input_ids = ids_tensor((2, 8), 10)
                    if self.is_encoder_decoder:
                        new_model_without_prefix(input_ids,
                                                 decoder_input_ids=input_ids)
                    else:
                        new_model_without_prefix(input_ids)

    @classmethod
    def from_text_vision_pretrained(
        cls,
        text_model_name_or_path: str = None,
        vision_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        """
        Params:
            text_model_name_or_path (:obj:`str`, `optional`):
                Information necessary to initialize the text model. Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to
                      a Flax model using the provided conversion scripts and loading the Flax model afterwards.

            vision_model_name_or_path (:obj:`str`, `optional`):
                Information necessary to initialize the vision model. Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to
                      a Flax model using the provided conversion scripts and loading the Flax model afterwards.

            model_args (remaining positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initialize the model
                (e.g., :obj:`output_attentions=True`).

                - To update the text configuration, use the prefix `text_` for each configuration parameter.
                - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.

        Example::

            >>> from transformers import FlaxHybridCLIP
            >>> # Initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
            >>> # If using CLIP's vision model, the vision projection layer will be initialized using pre-trained weights.
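            >>> # kwargs prefixed with `text_` or `vision_` are routed to the corresponding sub-model,
            >>> # e.g. `text_from_pt=True` would apply only when loading the text checkpoint.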
            >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
            >>> # saving model after fine-tuning
            >>> model.save_pretrained("./bert-clip")
            >>> # load fine-tuned model
            >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
        """

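        # Route kwargs by prefix: `text_*` goes to the text model, `vision_*` to the
        # vision model, and whatever is left is applied to the HybridCLIPConfig itself.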
        kwargs_text = {
            argument[len("text_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("text_")
        }

        kwargs_vision = {
            argument[len("vision_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("vision_")
        }

        # remove text, vision kwargs from kwargs
        for key in kwargs_text.keys():
            del kwargs["text_" + key]
        for key in kwargs_vision.keys():
            del kwargs["vision_" + key]

        # Load and initialize the text and vision models
        text_model = kwargs_text.pop("model", None)
        if text_model is None:
            assert (
                text_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
            from transformers import FlaxAutoModel

            if "config" not in kwargs_text:
                from transformers import AutoConfig

                text_config = AutoConfig.from_pretrained(
                    text_model_name_or_path)
                kwargs_text["config"] = text_config

            text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path,
                                                       *model_args,
                                                       **kwargs_text)

        vision_model = kwargs_vision.pop("model", None)
        if vision_model is None:
            assert (
                vision_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
            from transformers import FlaxAutoModel

            if "config" not in kwargs_vision:
                from transformers import AutoConfig

                vision_config = AutoConfig.from_pretrained(
                    vision_model_name_or_path)
                kwargs_vision["config"] = vision_config

            vision_model = FlaxAutoModel.from_pretrained(
                vision_model_name_or_path, *model_args, **kwargs_vision)

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = HybridCLIPConfig.from_text_vision_configs(
            text_model.config, vision_model.config, **kwargs)

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

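        # Use `vision_model.config` rather than the local `vision_config`, which is only
        # bound when the vision model was loaded from a name or path. A CLIP checkpoint
        # nests the vision tower and its projection inside the parent FlaxCLIPModel, so
        # those sub-params are copied explicitly; other backbones map over directly.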
        if vision_model.config.model_type == "clip":
            model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
            model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
        else:
            model.params["vision_model"] = vision_model.params

        model.params["text_model"] = text_model.params

        return model