def test_load_with_mismatched_shapes(self):
    """Verify checkpoint loading behaviour when parameter shapes differ.

    For every model class registered in the Flax sequence-classification
    mapping, a freshly initialised model is saved to a temporary directory
    and then reloaded with a configuration override that changes a
    parameter shape (``num_labels=42`` for the classification head,
    ``vocab_size=10`` for the base model). The test asserts that:

    - loading raises ``ValueError`` when ``ignore_mismatched_sizes`` is not
      passed;
    - loading with ``ignore_mismatched_sizes=True`` succeeds, logs a message
      containing "the shapes did not match", and yields a model that can be
      called with the new sizes.

    The test is a no-op when ``self.test_mismatched_shapes`` is falsy.
    """
    if not self.test_mismatched_shapes:
        return

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        # Only sequence-classification models are exercised here.
        if model_class not in get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
            continue

        with self.subTest(msg=f"Testing {model_class}"), tempfile.TemporaryDirectory() as checkpoint_dir:
            model_class(config).save_pretrained(checkpoint_dir)

            # Without ignore_mismatched_sizes=True both loads must fail.
            with self.assertRaises(ValueError):
                FlaxAutoModelForSequenceClassification.from_pretrained(checkpoint_dir, num_labels=42)
            with self.assertRaises(ValueError):
                FlaxAutoModel.from_pretrained(checkpoint_dir, vocab_size=10)

            flax_utils_logger = logging.get_logger("transformers.modeling_flax_utils")

            # Classification head resized: the load must log the mismatch
            # and the resulting logits must have the new label count.
            with CaptureLogger(flax_utils_logger) as captured:
                resized_classifier = FlaxAutoModelForSequenceClassification.from_pretrained(
                    checkpoint_dir, num_labels=42, ignore_mismatched_sizes=True
                )
            self.assertIn("the shapes did not match", captured.out)
            classifier_logits = resized_classifier(**inputs_dict)["logits"]
            self.assertEqual(classifier_logits.shape[1], 42)

            # Base model with a smaller vocabulary: the load must log the
            # mismatch and the model must accept ids from the new vocab.
            with CaptureLogger(flax_utils_logger) as captured:
                resized_base_model = FlaxAutoModel.from_pretrained(
                    checkpoint_dir, vocab_size=10, ignore_mismatched_sizes=True
                )
            self.assertIn("the shapes did not match", captured.out)

            input_ids = ids_tensor((2, 8), 10)
            # Encoder-decoder models additionally need decoder inputs.
            extra_inputs = {"decoder_input_ids": input_ids} if self.is_encoder_decoder else {}
            resized_base_model(input_ids, **extra_inputs)
def from_text_vision_pretrained(
    cls,
    text_model_name_or_path: str = None,
    vision_model_name_or_path: str = None,
    *model_args,
    **kwargs,
) -> FlaxPreTrainedModel:
    """
    Build a hybrid text/vision model from a pretrained text model and a pretrained vision model.

    Params:
        text_model_name_or_path (:obj: `str`, `optional`):
            Information necessary to initiate the text model. Can be either:

                - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                  a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
                  this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                  as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
                  a Flax model using the provided conversion scripts and loading the Flax model afterwards.

        vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
            Information necessary to initiate the vision model. Can be either:

                - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                  a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
                  this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                  as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
                  a Flax model using the provided conversion scripts and loading the Flax model afterwards.

        model_args (remaining positional arguments, `optional`):
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

        kwargs (remaining dictionary of keyword arguments, `optional`):
            Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
            :obj:`output_attentions=True`).

            - To update the text configuration, use the prefix `text_` for each configuration parameter.
            - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
            - To update the parent model configuration, do not use a prefix for each configuration parameter.

            Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.

            Special keys handled by this method:

            - ``text_model`` / ``vision_model``: an already instantiated model to use instead of loading one
              from ``text_model_name_or_path`` / ``vision_model_name_or_path``.
            - ``text_config`` / ``vision_config``: a configuration to use instead of the one loaded with
              ``AutoConfig.from_pretrained``.
            - ``dtype``: forwarded to the hybrid model constructor, defaults to ``jnp.float32``.

    Returns:
        The newly created model, whose ``text_model`` parameters (and ``vision_model`` parameters, plus the
        ``visual_projection`` kernel when the vision model is a CLIP model) are copied from the loaded models.

    Raises:
        AssertionError: If a sub-model is neither passed through ``text_model`` / ``vision_model`` nor
            identified by its ``*_name_or_path`` argument.

    Example::

        >>> from transformers import FlaxHybridCLIP
        >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
        >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
        >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./bert-clip")
        >>> # load fine-tuned model
        >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
    """
    # Split the prefixed kwargs into per-sub-model dictionaries (prefix stripped).
    kwargs_text = {
        argument[len("text_"):]: value for argument, value in kwargs.items() if argument.startswith("text_")
    }
    kwargs_vision = {
        argument[len("vision_"):]: value for argument, value in kwargs.items() if argument.startswith("vision_")
    }

    # remove text, vision kwargs from kwargs
    for key in kwargs_text:
        del kwargs["text_" + key]
    for key in kwargs_vision:
        del kwargs["vision_" + key]

    # Load and initialize the text and vision model
    text_model = kwargs_text.pop("model", None)
    if text_model is None:
        assert (
            text_model_name_or_path is not None
        ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
        from transformers import FlaxAutoModel

        if "config" not in kwargs_text:
            from transformers import AutoConfig

            kwargs_text["config"] = AutoConfig.from_pretrained(text_model_name_or_path)

        text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)

    vision_model = kwargs_vision.pop("model", None)
    if vision_model is None:
        assert (
            vision_model_name_or_path is not None
        ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
        from transformers import FlaxAutoModel

        if "config" not in kwargs_vision:
            from transformers import AutoConfig

            kwargs_vision["config"] = AutoConfig.from_pretrained(vision_model_name_or_path)

        vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)

    # instantiate config with corresponding kwargs
    dtype = kwargs.pop("dtype", jnp.float32)
    config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)

    # init model
    model = cls(config, *model_args, dtype=dtype, **kwargs)

    # Read the model type from the vision model's own config: a local
    # `vision_config` variable only exists when the config was auto-loaded
    # above, so relying on it fails when the caller supplies `vision_model`
    # or `vision_config` directly.
    if vision_model.config.model_type == "clip":
        # A CLIP checkpoint nests the vision tower under "vision_model" and
        # ships a pretrained visual projection that can be reused as-is.
        model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
        model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
    else:
        model.params["vision_model"] = vision_model.params

    model.params["text_model"] = text_model.params

    return model