def test_pt_tf_model_equivalence(self):
    if not is_torch_available():
        return

    import torch
    import transformers

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
        pt_model_class = getattr(transformers, pt_model_class_name)

        config.output_hidden_states = True
        tf_model = model_class(config)
        pt_model = pt_model_class(config)

        # Check we can load pt model in tf and vice-versa with model => model functions
        tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        pt_inputs_dict = {name: torch.from_numpy(value.numpy()).to(torch.long) for name, value in inputs_dict.items()}
        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)
        tfo = tf_model(inputs_dict)
        max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
        self.assertLessEqual(max_diff, 2e-2)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, 'pt_model.bin')
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, 'tf_model.h5')
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        pt_inputs_dict = {name: torch.from_numpy(value.numpy()).to(torch.long) for name, value in inputs_dict.items()}
        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)
        tfo = tf_model(inputs_dict)
        max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
        self.assertLessEqual(max_diff, 2e-2)
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""Instantiate a pretrained PyTorch model from a pre-trained model configuration.

    The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
    To train the model, you should first set it back in training mode with ``model.train()``.

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come
    pre-trained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:

            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
            - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``).

        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

        config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
            - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

        state_dict: (`optional`) dict:
            An optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
            This option can be used if you want to create a model from a pretrained configuration but load your own weights.
            In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and
            :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model configuration should be cached if the standard cache should not be used.

        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request.

        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g. ``output_attention=True``).
            Behaves differently depending on whether a `config` is provided or automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done).
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

    Examples::

        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    config = kwargs.pop('config', None)
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)
    output_loading_info = kwargs.pop('output_loading_info', False)

    # Load config
    if config is None:
        config, model_kwargs = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download, proxies=proxies, **kwargs)
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                # Load from a TF 1.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(
                pretrained_model_name_or_path)
            archive_file = pretrained_model_name_or_path + ".index"

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
            else:
                msg = "Model name '{}' was not found in model name list ({}). " \
                      "We assumed '{}' was a path or url to model weight files named one of {} but " \
                      "couldn't find any such file at this path or url.".format(
                          pretrained_model_name_or_path,
                          ', '.join(cls.pretrained_model_archive_map.keys()),
                          archive_file,
                          [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        state_dict = torch.load(resolved_archive_file, map_location='cpu')

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    if from_tf:
        if resolved_archive_file.endswith('.index'):
            # Load from a TensorFlow 1.X checkpoint - provided by original authors
            model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
        else:
            # Load from our TensorFlow 2.0 checkpoints
            try:
                from transformers import load_tf2_checkpoint_in_pytorch_model
                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
            except ImportError as e:
                logger.error(
                    "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
                    "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
                raise e
    else:
        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
        # so we need to apply the function recursively.
        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            start_prefix = cls.base_model_prefix + '.'
        if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            model_to_load = getattr(model, cls.base_model_prefix)

        load(model_to_load, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

    if hasattr(model, 'tie_weights'):
        model.tie_weights()  # make sure word embedding weights are still tied

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
        return model, loading_info

    return model
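
# A brief, hedged usage sketch of the loading-info path documented above. The model class
# and checkpoint name are just illustrative examples from transformers; any head model whose
# head weights are not in the checkpoint behaves the same way.
from transformers import BertForSequenceClassification

model, loading_info = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", output_loading_info=True)
print(loading_info["missing_keys"])     # e.g. the freshly initialized classifier weights
print(loading_info["unexpected_keys"])  # checkpoint weights the head model does not use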
def test_pt_tf_model_equivalence(self):
    for model_class in self.all_model_classes:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
            return_obj_labels="PreTraining" in model_class.__name__)

        tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning
        if not hasattr(transformers, tf_model_class_name):
            # transformers does not have a TF version yet
            return

        tf_model_class = getattr(transformers, tf_model_class_name)

        config.output_hidden_states = True
        config.task_obj_predict = False

        pt_model = model_class(config)
        tf_model = tf_model_class(config)

        # Check we can load pt model in tf and vice-versa with model => model functions
        pt_inputs = self._prepare_for_class(inputs_dict, model_class)

        def recursive_numpy_convert(iterable):
            return_dict = {}
            for key, value in iterable.items():
                if type(value) == bool:
                    return_dict[key] = value
                elif isinstance(value, dict):
                    return_dict[key] = recursive_numpy_convert(value)
                elif isinstance(value, (list, tuple)):
                    return_dict[key] = (
                        tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value)
                else:
                    return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)
            return return_dict

        tf_inputs_dict = recursive_numpy_convert(pt_inputs)

        tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()

        # Delete obj labels as we want to compute the hidden states and not the loss
        if "obj_labels" in inputs_dict:
            del inputs_dict["obj_labels"]

        pt_inputs = self._prepare_for_class(inputs_dict, model_class)
        tf_inputs_dict = recursive_numpy_convert(pt_inputs)

        with torch.no_grad():
            pto = pt_model(**pt_inputs)
        tfo = tf_model(tf_inputs_dict, training=False)
        tf_hidden_states = tfo[0].numpy()
        pt_hidden_states = pto[0].cpu().numpy()

        tf_nans = np.copy(np.isnan(tf_hidden_states))
        pt_nans = np.copy(np.isnan(pt_hidden_states))

        pt_hidden_states[tf_nans] = 0
        tf_hidden_states[tf_nans] = 0
        pt_hidden_states[pt_nans] = 0
        tf_hidden_states[pt_nans] = 0

        max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
        # Debug info (remove when fixed)
        if max_diff >= 2e-2:
            print("===")
            print(model_class)
            print(config)
            print(inputs_dict)
            print(pt_inputs)
        self.assertLessEqual(max_diff, 6e-2)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        for key, value in pt_inputs.items():
            if key in ("visual_feats", "visual_pos"):
                pt_inputs[key] = value.to(torch.float32)
            else:
                pt_inputs[key] = value.to(torch.long)

        with torch.no_grad():
            pto = pt_model(**pt_inputs)
        tfo = tf_model(tf_inputs_dict)
        tfo = tfo[0].numpy()
        pto = pto[0].cpu().numpy()

        tf_nans = np.copy(np.isnan(tfo))
        pt_nans = np.copy(np.isnan(pto))

        pto[tf_nans] = 0
        tfo[tf_nans] = 0
        pto[pt_nans] = 0
        tfo[pt_nans] = 0

        max_diff = np.amax(np.abs(tfo - pto))
        self.assertLessEqual(max_diff, 6e-2)
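
# The NaN-masking comparison used by the equivalence tests above can be factored into a small
# standalone helper. This is only a sketch of the same logic, not a helper that exists in
# transformers.
import numpy as np

def max_abs_diff_ignoring_nans(tf_out, pt_out):
    """Zero out positions where either framework produced NaN, then take the max abs diff."""
    tf_out = np.array(tf_out, copy=True)
    pt_out = np.array(pt_out, copy=True)
    nan_mask = np.isnan(tf_out) | np.isnan(pt_out)
    tf_out[nan_mask] = 0
    pt_out[nan_mask] = 0
    return float(np.amax(np.abs(tf_out - pt_out)))

# Usage mirroring the assertion above:
# self.assertLessEqual(max_abs_diff_ignoring_nans(tfo[0].numpy(), pto[0].cpu().numpy()), 6e-2)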
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    """Instantiate a pretrained PyTorch model from a pre-trained model configuration.

    The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
    To train the model, you should first set it back in training mode with ``model.train()``.

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come
    pre-trained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:

            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
            - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``).

        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

        config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
            - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

        state_dict: (`optional`) dict:
            An optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
            This option can be used if you want to create a model from a pretrained configuration but load your own weights.
            In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and
            :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model configuration should be cached if the standard cache should not be used.

        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request.

        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g. ``output_attention=True``).
            Behaves differently depending on whether a `config` is provided or automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done).
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

    Examples::

        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_tf = kwargs.pop("from_tf", False)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    random_init = kwargs.pop("random_init", False)
    use_cdn = kwargs.pop("use_cdn", True)
    local_files_only = kwargs.pop("local_files_only", False)
    resume_download = kwargs.pop("resume_download", False)
    kwargs_config = kwargs.copy()
    mapping_keys_state_dic = kwargs.pop("mapping_keys_state_dic", None)
    kwargs_config.pop("mapping_keys_state_dic", None)

    # Load config
    if config is None:
        config, model_kwargs = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download, **kwargs_config)
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                # Load from a TF 1.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            assert from_tf, (
                "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this "
                "checkpoint".format(pretrained_model_name_or_path + ".index"))
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                use_cdn=use_cdn)

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies,
                resume_download=resume_download, local_files_only=local_files_only)
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n")
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        state_dict = torch.load(resolved_archive_file, map_location="cpu")

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    if from_tf:
        if resolved_archive_file.endswith(".index"):
            # Load from a TensorFlow 1.X checkpoint - provided by original authors
            model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
        else:
            # Load from our TensorFlow 2.0 checkpoints
            try:
                from transformers import load_tf2_checkpoint_in_pytorch_model
                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
            except ImportError as e:
                logger.error(
                    "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
                    "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
                raise e
    else:
        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # assert mapping_keys_state_dic is not None, "ERROR did not find mapping dicts for {} ".format(pretrained_model_name_or_path)
        # mapping_keys_state_dic = {"roberta": "encoder", "lm_head": "head.mlm"}
        if mapping_keys_state_dic is not None:
            assert isinstance(mapping_keys_state_dic, dict), "ERROR "
            print("INFO : from_pretrained (assuming loading an original google model): "
                  "need to rename some keys {}".format(mapping_keys_state_dic))
            state_dict = cls.adapt_state_dic_to_multitask(
                state_dict, keys_mapping=mapping_keys_state_dic,
                add_prefix=pretrained_model_name_or_path == "asafaya/bert-base-arabic")

        # pdb.set_trace()
        def load(module, prefix=""):
            local_metadata = {"version": 1}
            if not prefix.startswith("head") or prefix.startswith("head.mlm"):
                assert len(missing_keys) == 0, "ERROR {} missing keys in state_dict {}".format(prefix, missing_keys)
            else:
                if len(missing_keys) == 0:
                    print("Warning {} missing keys in state_dict {} (warning expected for task-specific fine-tuning)"
                          .format(prefix, missing_keys))
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                # load_params_only_ls = kwargs.get("load_params_only_ls ")
                not_load_params_ls = kwargs.get("not_load_params_ls") if kwargs.get("not_load_params_ls") is not None else []
                assert isinstance(not_load_params_ls, list), \
                    f"Argument error not_load_params_ls should be a list but is {not_load_params_ls}"
                matching_not_load = []
                # RANDOM-INIT
                for pattern in not_load_params_ls:
                    matching = re.match(pattern, prefix + name)
                    if matching is not None:
                        matching_not_load.append(matching)
                if len(matching_not_load) > 0:
                    # at least one pattern in not_load_params_ls matched --> this submodule will not be loaded
                    print("MATCH not loading : {} parameters {} ".format(prefix + name, not_load_params_ls))
                if child is not None and len(matching_not_load) == 0:
                    # print("MODEL loading : child {} full {} ".format(name, prefix + name + '.'))
                    load(child, prefix + name + ".")
                else:
                    print("MODEL not loading : child {} matching_not_load {} ".format(child, matching_not_load))

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            start_prefix = cls.base_model_prefix + "."
        if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            model_to_load = getattr(model, cls.base_model_prefix)

        if not random_init:
            load(model_to_load, prefix=start_prefix)
        else:
            print("WARNING : RANDOM INITIALIZATION OF BERTMULTITASK")

        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

    if hasattr(model, "tie_weights"):
        model.tie_weights()  # make sure word embedding weights are still tied

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
        return model, loading_info

    return model
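
# A hedged sketch of how the extra keyword arguments consumed by the override above might be
# used at a call site. `BertMultiTask` is a hypothetical class name (the override only hints at
# it in its warning message); the mapping dict mirrors the commented-out example in the code,
# and the regex pattern for `not_load_params_ls` is purely illustrative.
model = BertMultiTask.from_pretrained(
    "bert-base-cased",
    mapping_keys_state_dic={"roberta": "encoder", "lm_head": "head.mlm"},  # rename checkpoint keys
    not_load_params_ls=[r"head\."],   # submodules matching these patterns keep their random init
    random_init=False,                # True would skip loading the pretrained weights entirely
)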
def from_pretrained(cls, pretrained_model_name_or_path, eigvecs_dict, *model_args, **kwargs):
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_tf = kwargs.pop("from_tf", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download, resume_download=resume_download, proxies=proxies,
            local_files_only=local_files_only, **kwargs)
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                # Load from a TF 1.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            assert from_tf, (
                "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this "
                "checkpoint".format(pretrained_model_name_or_path + ".index"))
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME))

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies,
                resume_download=resume_download, local_files_only=local_files_only)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
            else:
                msg = (
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url to model weight files named one of {} but "
                    "couldn't find any such file at this path or url.".format(
                        pretrained_model_name_or_path,
                        ", ".join(cls.pretrained_model_archive_map.keys()),
                        archive_file,
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME]))
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, eigvecs_dict, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        try:
            state_dict = torch.load(resolved_archive_file, map_location="cpu")
        except Exception:
            raise OSError(
                "Unable to load weights from pytorch checkpoint file. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ")

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    if from_tf:
        if resolved_archive_file.endswith(".index"):
            # Load from a TensorFlow 1.X checkpoint - provided by original authors
            model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
        else:
            # Load from our TensorFlow 2.0 checkpoints
            try:
                from transformers import load_tf2_checkpoint_in_pytorch_model
                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
            except ImportError:
                logger.error(
                    "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
                    "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
                raise
    else:
        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
        # so we need to apply the function recursively.
        def load(module: nn.Module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            start_prefix = cls.base_model_prefix + "."
        if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            model_to_load = getattr(model, cls.base_model_prefix)

        load(model_to_load, prefix=start_prefix)

        if model.__class__.__name__ != model_to_load.__class__.__name__:
            base_model_state_dict = model_to_load.state_dict().keys()
            head_model_state_dict_without_base_prefix = [
                key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()]
            missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

    model.tie_weights()  # make sure token embedding weights are still tied if needed

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
        return model, loading_info

    return model
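
# Hypothetical call site for the override above: `eigvecs_dict` is a positional argument that
# is forwarded to the model constructor, while the transformer weights still come from the
# checkpoint. `GraphBertModel` and the eigenvector payload are assumptions for illustration only.
import numpy as np

eigvecs = {"layer_0": np.eye(768, dtype=np.float32)}  # placeholder precomputed eigenvectors
model = GraphBertModel.from_pretrained("bert-base-uncased", eigvecs)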
def test_pt_tf_model_equivalence(self):
    if not is_torch_available():
        return

    import torch
    import transformers

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
        pt_model_class = getattr(transformers, pt_model_class_name)

        config.output_hidden_states = True
        tf_model = model_class(config)
        pt_model = pt_model_class(config)

        # Check we can load pt model in tf and vice-versa with model => model functions
        tf_model = transformers.load_pytorch_model_in_tf2_model(
            tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class))
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        pt_inputs_dict = {
            name: torch.from_numpy(value.numpy()).to(torch.long)
            for name, value in self._prepare_for_class(inputs_dict, model_class).items()
        }
        # need to rename encoder-decoder "inputs" for PyTorch
        if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)
        tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
        tf_hidden_states = tfo[0].numpy()
        pt_hidden_states = pto[0].numpy()

        tf_nans = np.copy(np.isnan(tf_hidden_states))
        pt_nans = np.copy(np.isnan(pt_hidden_states))

        pt_hidden_states[tf_nans] = 0
        tf_hidden_states[tf_nans] = 0
        pt_hidden_states[pt_nans] = 0
        tf_hidden_states[pt_nans] = 0

        max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
        # Debug info (remove when fixed)
        if max_diff >= 2e-2:
            print("===")
            print(model_class)
            print(config)
            print(inputs_dict)
            print(pt_inputs_dict)
        self.assertLessEqual(max_diff, 2e-2)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        pt_inputs_dict = {
            name: torch.from_numpy(value.numpy()).to(torch.long)
            for name, value in self._prepare_for_class(inputs_dict, model_class).items()
        }
        # need to rename encoder-decoder "inputs" for PyTorch
        if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)
        tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
        tfo = tfo[0].numpy()
        pto = pto[0].numpy()

        tf_nans = np.copy(np.isnan(tfo))
        pt_nans = np.copy(np.isnan(pto))

        pto[tf_nans] = 0
        tfo[tf_nans] = 0
        pto[pt_nans] = 0
        tfo[pt_nans] = 0

        max_diff = np.amax(np.abs(tfo - pto))
        self.assertLessEqual(max_diff, 2e-2)
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_tf = kwargs.pop("from_tf", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download, resume_download=resume_download, proxies=proxies,
            local_files_only=local_files_only, **kwargs)
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                # Load from a TF 1.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            assert from_tf, (
                "We found a TensorFlow checkpoint at {}, please set from_tf to True"
                " to load from this checkpoint").format(pretrained_model_name_or_path + ".index")
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                use_cdn=use_cdn)

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies,
                resume_download=resume_download, local_files_only=local_files_only)
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make "
                f"sure that:\n\n- '{pretrained_model_name_or_path}' is a correct "
                f"model identifier listed on 'https://huggingface.co/models'\n\n- "
                f"or '{pretrained_model_name_or_path}' is the correct path to a "
                f"directory containing a file named one of {WEIGHTS_NAME}, "
                f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n")
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        try:
            state_dict = torch.load(resolved_archive_file, map_location="cpu")
        except Exception:
            raise OSError(
                "Unable to load weights from pytorch checkpoint file. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ")

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    if from_tf:
        if resolved_archive_file.endswith(".index"):
            # Load from a TensorFlow 1.X checkpoint - provided by original authors
            model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
        else:
            # Load from our TensorFlow 2.0 checkpoints
            try:
                from transformers import load_tf2_checkpoint_in_pytorch_model
                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
            except ImportError:
                logger.error(
                    "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                    "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
                raise
    else:
        # Convert old format to new format if needed from a PyTorch state_dict
        has_all_sub_modules = all(
            any(s.startswith(_prefix) for s in state_dict.keys()) for _prefix in model.base_model_prefixs)
        has_prefix_module = any(s.startswith(model.base_model_prefix) for s in state_dict.keys())
        old_keys = list(state_dict.keys())
        for key in old_keys:
            new_key = key
            if "gamma" in key:
                new_key = new_key.replace("gamma", "weight")
            if "beta" in key:
                new_key = new_key.replace("beta", "bias")
            _state = state_dict.pop(key)
            if has_all_sub_modules:
                state_dict[new_key] = _state
            elif not has_prefix_module:
                for _prefix in model.base_model_prefixs:
                    _key = _prefix + "." + new_key
                    state_dict[_key] = _state
            else:
                if new_key.startswith(model.base_model_prefix):
                    for _prefix in model.base_model_prefixs:
                        _key = _prefix + new_key[len(model.base_model_prefix):]
                        state_dict[_key] = _state
                else:
                    state_dict[new_key] = _state

        if hasattr(model, "hack_pretrained_state_dict"):
            state_dict = model.hack_pretrained_state_dict(state_dict)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
        # so we need to apply the function recursively.
        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # Make sure we are able to load base models as well as derived models (with heads)
        load(model)
        load = None

    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)))

    model.tie_weights()  # make sure token embedding weights are still tied if needed

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
        return model, loading_info

    if hasattr(config, "xla_device") and config.xla_device:
        import torch_xla.core.xla_model as xm

        model = xm.send_cpu_data_to_device(model, xm.xla_device())
        model.to(xm.xla_device())

    return model
def test_pt_tf_model_equivalence(self):
    import numpy as np
    import tensorflow as tf
    import transformers

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning
        if not hasattr(transformers, tf_model_class_name):
            # transformers does not have a TF version yet
            return

        tf_model_class = getattr(transformers, tf_model_class_name)

        config.output_hidden_states = True

        tf_model = tf_model_class(config)
        pt_model = model_class(config)

        # make sure only tf inputs are forwarded that actually exist in function args
        tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())

        # remove all head masks
        tf_input_keys.discard("head_mask")
        tf_input_keys.discard("cross_attn_head_mask")
        tf_input_keys.discard("decoder_head_mask")

        pt_inputs = self._prepare_for_class(inputs_dict, model_class)
        pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        tf_inputs_dict = {}
        for key, tensor in pt_inputs.items():
            # skip key that does not exist in tf
            if type(tensor) == bool:
                tf_inputs_dict[key] = tensor
            elif key == "input_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
            elif key == "pixel_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
            else:
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)

        # Check we can load pt model in tf and vice-versa with model => model functions
        tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

        # need to rename encoder-decoder "inputs" for PyTorch
        # if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
        #     pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

        with torch.no_grad():
            pto = pt_model(**pt_inputs)

        tfo = tf_model(tf_inputs_dict, training=False)

        self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
        for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
            if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                continue

            tf_out = tf_output.numpy()
            pt_out = pt_output.numpy()

            self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

            if len(tf_out.shape) > 0:
                tf_nans = np.copy(np.isnan(tf_out))
                pt_nans = np.copy(np.isnan(pt_out))

                pt_out[tf_nans] = 0
                tf_out[tf_nans] = 0
                pt_out[pt_nans] = 0
                tf_out[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_out - pt_out))
            self.assertLessEqual(max_diff, 4e-2)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        tf_inputs_dict = {}
        for key, tensor in pt_inputs.items():
            # skip key that does not exist in tf
            if type(tensor) == bool:
                tensor = np.array(tensor, dtype=bool)
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
            elif key == "input_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
            elif key == "pixel_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
            else:
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)

        # need to rename encoder-decoder "inputs" for PyTorch
        # if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
        #     pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

        with torch.no_grad():
            pto = pt_model(**pt_inputs)

        tfo = tf_model(tf_inputs_dict)

        self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
        for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
            if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                continue

            tf_out = tf_output.numpy()
            pt_out = pt_output.numpy()

            self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

            if len(tf_out.shape) > 0:
                tf_nans = np.copy(np.isnan(tf_out))
                pt_nans = np.copy(np.isnan(pt_out))

                pt_out[tf_nans] = 0
                tf_out[tf_nans] = 0
                pt_out[pt_nans] = 0
                tf_out[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_out - pt_out))
            self.assertLessEqual(max_diff, 4e-2)
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_tf = kwargs.pop("from_tf", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)

    if not isinstance(config, BertConfig.PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path, *model_args, cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download, resume_download=resume_download, proxies=proxies,
            local_files_only=local_files_only, **kwargs)
    else:
        model_kwargs = kwargs

    if pretrained_model_name_or_path is not None:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or \
                urlparse(pretrained_model_name_or_path).scheme in ["http", "https", "s3"]:
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            assert from_tf, "We found a TensorFlow checkpoint at {}".format(pretrained_model_name_or_path + ".index")
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = "/".join((S3_BUCKET_PREFIX, pretrained_model_name_or_path,
                                     (TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME)))

        try:
            resolved_archive_file = BertConfig.cached_path(
                archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies,
                resume_download=resume_download, local_files_only=local_files_only)
        except EnvironmentError:
            raise EnvironmentError

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        try:
            state_dict = torch.load(resolved_archive_file, map_location="cpu")
        except Exception:
            raise OSError("Unable to load weights from pytorch checkpoint file. ")

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    if from_tf:
        if resolved_archive_file.endswith(".index"):
            model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])
        else:
            try:
                from transformers import load_tf2_checkpoint_in_pytorch_model
                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
            except ImportError:
                raise
    else:
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module: nn.Module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        start_prefix = ""
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            start_prefix = cls.base_model_prefix + "."
        if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            model_to_load = getattr(model, cls.base_model_prefix)

        load(model_to_load, prefix=start_prefix)

        if model.__class__.__name__ != model_to_load.__class__.__name__:
            base_model_state_dict = model_to_load.state_dict().keys()
            head_model_state_dict_without_base_prefix = [
                key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()]
            missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

    model.tie_weights()
    model.eval()

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
        return model, loading_info

    return model
def test_pt_tf_model_equivalence(self):
    from transformers import is_torch_available

    if not is_torch_available():
        return

    import os
    import tempfile

    import numpy as np
    import torch
    import transformers

    for model_class in self.all_model_classes:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
            return_obj_labels="PreTraining" in model_class.__name__
        )

        pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
        pt_model_class = getattr(transformers, pt_model_class_name)

        config.output_hidden_states = True
        config.task_obj_predict = False

        tf_model = model_class(config)
        pt_model = pt_model_class(config)

        # Check we can load pt model in tf and vice-versa with model => model functions
        tf_model = transformers.load_pytorch_model_in_tf2_model(
            tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
        )
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()

        # Delete obj labels as we want to compute the hidden states and not the loss
        if "obj_labels" in inputs_dict:
            del inputs_dict["obj_labels"]

        def torch_type(key):
            if key in ("visual_feats", "visual_pos"):
                return torch.float32
            else:
                return torch.long

        def recursive_numpy_convert(iterable):
            return_dict = {}
            for key, value in iterable.items():
                if isinstance(value, dict):
                    return_dict[key] = recursive_numpy_convert(value)
                else:
                    if isinstance(value, (list, tuple)):
                        return_dict[key] = (
                            torch.from_numpy(iter_value.numpy()).to(torch_type(key)) for iter_value in value
                        )
                    else:
                        return_dict[key] = torch.from_numpy(value.numpy()).to(torch_type(key))
            return return_dict

        pt_inputs_dict = recursive_numpy_convert(self._prepare_for_class(inputs_dict, model_class))

        # need to rename encoder-decoder "inputs" for PyTorch
        if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)
        tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
        tf_hidden_states = tfo[0].numpy()
        pt_hidden_states = pto[0].numpy()

        tf_nans = np.copy(np.isnan(tf_hidden_states))
        pt_nans = np.copy(np.isnan(pt_hidden_states))

        pt_hidden_states[tf_nans] = 0
        tf_hidden_states[tf_nans] = 0
        pt_hidden_states[pt_nans] = 0
        tf_hidden_states[pt_nans] = 0

        max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
        # Debug info (remove when fixed)
        if max_diff >= 2e-2:
            print("===")
            print(model_class)
            print(config)
            print(inputs_dict)
            print(pt_inputs_dict)
        self.assertLessEqual(max_diff, 6e-2)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        # Convert without forcing an intermediate integer cast so that float inputs
        # (visual_feats, visual_pos) keep their values; the loop below picks the right dtype per key.
        pt_inputs_dict = dict(
            (name, torch.from_numpy(key.numpy()))
            for name, key in self._prepare_for_class(inputs_dict, model_class).items()
        )

        for key, value in pt_inputs_dict.items():
            if key in ("visual_feats", "visual_pos"):
                pt_inputs_dict[key] = value.to(torch.float32)
            else:
                pt_inputs_dict[key] = value.to(torch.long)

        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)
        tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
        tfo = tfo[0].numpy()
        pto = pto[0].numpy()

        tf_nans = np.copy(np.isnan(tfo))
        pt_nans = np.copy(np.isnan(pto))

        pto[tf_nans] = 0
        tfo[tf_nans] = 0
        pto[pt_nans] = 0
        tfo[pt_nans] = 0

        max_diff = np.amax(np.abs(tfo - pto))
        self.assertLessEqual(max_diff, 6e-2)
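# The NaN-masking / max-absolute-difference pattern used above is repeated in several of these
# equivalence tests. The helper below is a small self-contained sketch of that comparison; it is
# not part of the test suite, and the function name and default tolerance are illustrative only.
import numpy as np


def max_abs_diff_ignoring_nans(a, b):
    """Return the largest absolute difference between `a` and `b`, zeroing out positions
    where either array is NaN so that stray NaNs do not dominate the result."""
    a = np.array(a, dtype=np.float64, copy=True)
    b = np.array(b, dtype=np.float64, copy=True)
    nan_mask = np.isnan(a) | np.isnan(b)
    a[nan_mask] = 0
    b[nan_mask] = 0
    return float(np.amax(np.abs(a - b)))


# Example: a stray NaN in one output does not blow up the comparison.
print(max_abs_diff_ignoring_nans([1.0, np.nan, 3.0], [1.0, 2.0, 3.01]))  # ~0.01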
def test_pt_tf_model_equivalence(self):
    import torch
    import transformers

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
        pt_model_class = getattr(transformers, pt_model_class_name)

        config.output_hidden_states = True

        tf_model = model_class(config)
        pt_model = pt_model_class(config)

        # Check we can load pt model in tf and vice-versa with model => model functions
        tf_model = transformers.load_pytorch_model_in_tf2_model(
            tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
        )
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        pt_inputs_dict = {}
        for name, key in self._prepare_for_class(inputs_dict, model_class).items():
            if type(key) == bool:
                pt_inputs_dict[name] = key
            elif name == "input_values":
                pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
            elif name == "pixel_values":
                pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
            else:
                pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)

        # need to rename encoder-decoder "inputs" for PyTorch
        if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)
        tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)

        self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
        for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
            if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                continue

            tf_out = tf_output.numpy()
            pt_out = pt_output.numpy()

            self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

            if len(tf_out.shape) > 0:
                tf_nans = np.copy(np.isnan(tf_out))
                pt_nans = np.copy(np.isnan(pt_out))

                pt_out[tf_nans] = 0
                tf_out[tf_nans] = 0
                pt_out[pt_nans] = 0
                tf_out[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_out - pt_out))
            self.assertLessEqual(max_diff, 4e-2)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()
        pt_inputs_dict = {}
        for name, key in self._prepare_for_class(inputs_dict, model_class).items():
            if type(key) == bool:
                key = np.array(key, dtype=bool)
                pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
            elif name == "input_values":
                pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
            elif name == "pixel_values":
                pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
            else:
                pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)

        # need to rename encoder-decoder "inputs" for PyTorch
        if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

        with torch.no_grad():
            pto = pt_model(**pt_inputs_dict)

        tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))

        self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
        for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
            if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                continue

            tf_out = tf_output.numpy()
            pt_out = pt_output.numpy()
            self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

            if len(tf_out.shape) > 0:
                tf_nans = np.copy(np.isnan(tf_out))
                pt_nans = np.copy(np.isnan(pt_out))

                pt_out[tf_nans] = 0
                tf_out[tf_nans] = 0
                pt_out[pt_nans] = 0
                tf_out[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_out - pt_out))
            self.assertLessEqual(max_diff, 4e-2)
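# The dtype dispatch in the test above (boolean flags passed through, `input_values`/`pixel_values`
# kept as float32, everything else cast to int64) can be factored into a small helper. This is an
# illustrative sketch only: the helper name and the key list are assumptions, it requires PyTorch
# and numpy, and the tests above keep the logic inline.
import numpy as np
import torch

FLOAT_INPUT_NAMES = ("input_values", "pixel_values", "visual_feats", "visual_pos")


def tf_tensor_to_pt(name, value):
    """Convert a single TF eager tensor (or plain Python flag) to its PyTorch equivalent."""
    if isinstance(value, bool):
        return value  # flags such as output_attentions are passed through unchanged
    array = np.asarray(value.numpy() if hasattr(value, "numpy") else value)
    dtype = torch.float32 if name in FLOAT_INPUT_NAMES else torch.long
    return torch.from_numpy(array).to(dtype)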
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""
    Instantiate a pretrained pytorch model from a pre-trained model configuration.

    The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). To
    train the model, you should first set it back in training mode with ``model.train()``.

    The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
    task.

    The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path (:obj:`str`, `optional`):
            Can be either:

                - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                  ``bert-base-uncased``.
                - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                  ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
                  this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
                  as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint
                  into a PyTorch model using the provided conversion scripts and loading the PyTorch model
                  afterwards.
                - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                  arguments ``config`` and ``state_dict``).
        model_args (sequence of positional arguments, `optional`):
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
        config (:obj:`typing.Union[PretrainedConfig, str]`, `optional`):
            Can be either:

                - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

                - The model is a model provided by the library (loaded with the `shortcut name` string of a
                  pretrained model).
                - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                  by supplying the save directory.
                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                  configuration JSON file named `config.json` is found in the directory.
        state_dict (:obj:`typing.Dict[str, torch.Tensor]`, `optional`):
            A state dictionary to use instead of a state dictionary loaded from saved weights file.

            This option can be used if you want to create a model from a pretrained configuration but load your
            own weights. In this case though, you should check if using
            :func:`~transformers.PreTrainedModel.save_pretrained` and
            :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
        cache_dir (:obj:`str`, `optional`):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Load the model weights from a TensorFlow checkpoint save file (see docstring of
            ``pretrained_model_name_or_path`` argument).
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
            file exists.
        proxies (:obj:`typing.Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only look at local files (e.g., not try downloading the model).
        use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
        kwargs (remaining dictionary of keyword arguments, `optional`):
            Can be used to update the configuration object (after it has been loaded) and initialize the model
            (e.g., :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided
            or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                  already been done).
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                  attribute will be passed to the underlying model's ``__init__`` function.

    Examples::

        from transformers import BertConfig, BertModel
        # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('bert-base-uncased')
        # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
        model = BertModel.from_pretrained('./test/saved_model/')
        # Update configuration during loading.
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_tf = kwargs.pop("from_tf", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)

    model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                # Load from a TF 1.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path,
                    )
                )
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                pretrained_model_name_or_path + ".index"
            )
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                use_cdn=use_cdn,
            )

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            l.logger().info("loading weights file {}".format(archive_file))
        else:
            l.logger().info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        try:
            state_dict = torch.load(resolved_archive_file, map_location="cpu")
        except Exception:
            raise OSError(
                "Unable to load weights from pytorch checkpoint file. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
            )
" ) missing_keys = [] unexpected_keys = [] error_msgs = [] if from_tf: if resolved_archive_file.endswith(".index"): # Load from a torch.TensorFlow 1.X checkpoint - provided by original authors model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' else: # Load from our torch.TensorFlow 2.0 checkpoints try: from transformers import load_tf2_checkpoint_in_pytorch_model model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) except ImportError: l.logger().error( "Loading a torch.TensorFlow model in PyTorch, requires both PyTorch and torch.TensorFlow to be installed. Please see " "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." ) raise else: # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if "gamma" in key: new_key = key.replace("gamma", "weight") if "beta" in key: new_key = key.replace("beta", "bias") if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. def load(module: torch.nn.Module, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs, ) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + ".") # Make sure we are able to load base models as well as derived models (with heads) start_prefix = "" model_to_load = model has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) if not hasattr(model, cls.base_model_prefix) and has_prefix_module: start_prefix = cls.base_model_prefix + "." if hasattr(model, cls.base_model_prefix) and not has_prefix_module: model_to_load = getattr(model, cls.base_model_prefix) load(model_to_load, prefix=start_prefix) if model.__class__.__name__ != model_to_load.__class__.__name__: base_model_state_dict = model_to_load.state_dict().keys() head_model_state_dict_without_base_prefix = [ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() ] missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) # Some models may have keys that are not in the state by design, removing them before needlessly warning # the user. if cls.authorized_missing_keys is not None: for pat in cls.authorized_missing_keys: missing_keys = [k for k in missing_keys if re.search(pat, k) is None] if len(unexpected_keys) > 0: l.logger().warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " f"or with another architecture (e.g. 
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            l.logger().info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            l.logger().warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            l.logger().info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )

    # make sure token embedding weights are still tied if needed
    model.tie_weights()

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs,
        }
        return model, loading_info

    if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
        model = pytorch.xla_model.send_cpu_data_to_device(model, pytorch.xla_model.xla_device())
        model.to(pytorch.xla_model.xla_device())

    return model
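# Both from_pretrained variants above normalize legacy parameter names ("gamma"/"beta" from old
# TensorFlow-style checkpoints) to the PyTorch names ("weight"/"bias") before loading the state
# dict. The snippet below is a small standalone sketch of that renaming on a plain dict; the
# checkpoint keys are made up for illustration.
legacy_state_dict = {
    "bert.embeddings.LayerNorm.gamma": "tensor-0",
    "bert.embeddings.LayerNorm.beta": "tensor-1",
    "bert.encoder.layer.0.attention.self.query.weight": "tensor-2",
}

renamed = {}
for key, value in legacy_state_dict.items():
    new_key = key
    if "gamma" in new_key:
        new_key = new_key.replace("gamma", "weight")
    if "beta" in new_key:
        new_key = new_key.replace("beta", "bias")
    renamed[new_key] = value

assert "bert.embeddings.LayerNorm.weight" in renamed
assert "bert.embeddings.LayerNorm.bias" in renamed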
def test_pt_tf_model_equivalence(self):
    import numpy as np
    import tensorflow as tf
    import transformers

    # make masks reproducible
    np.random.seed(2)

    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    num_patches = int((config.image_size // config.patch_size) ** 2)
    noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
    pt_noise = torch.from_numpy(noise).to(device=torch_device)
    tf_noise = tf.constant(noise)

    def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
        tf_inputs_dict = {}
        for key, tensor in pt_inputs_dict.items():
            tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
        return tf_inputs_dict

    def check_outputs(tf_outputs, pt_outputs, model_class, names):
        """
        Args:
            model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
                `TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
                debugging easier and faster.
            names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the
                model outputs. Currently unused, but in the future, we could use this information to make the
                error message clearer by giving the name(s) of the output tensor(s) with large difference(s)
                between PT and TF.
        """
        # Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
        if type(tf_outputs) in [tuple, list]:
            self.assertEqual(type(tf_outputs), type(pt_outputs))
            self.assertEqual(len(tf_outputs), len(pt_outputs))
            if type(names) == tuple:
                for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
                    check_outputs(tf_output, pt_output, model_class, names=name)
            elif type(names) == str:
                for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
                    check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
            else:
                raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
        elif isinstance(tf_outputs, tf.Tensor):
            self.assertTrue(isinstance(pt_outputs, torch.Tensor))

            tf_outputs = tf_outputs.numpy()
            if isinstance(tf_outputs, np.float32):
                tf_outputs = np.array(tf_outputs, dtype=np.float32)
            pt_outputs = pt_outputs.detach().to("cpu").numpy()

            tf_nans = np.isnan(tf_outputs)
            pt_nans = np.isnan(pt_outputs)

            pt_outputs[tf_nans] = 0
            tf_outputs[tf_nans] = 0
            pt_outputs[pt_nans] = 0
            tf_outputs[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
            self.assertLessEqual(max_diff, 1e-5)
        else:
            raise ValueError(
                f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
            )

    def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict):
        # we are not preparing a model with labels because of the formation of the ViT MAE model

        # send pytorch model to the correct device
        pt_model.to(torch_device)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()

        tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)

        # send pytorch inputs to the correct device
        pt_inputs_dict = {
            k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
        }

        # Original test: check without `labels`
        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
        tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)

        tf_keys = tuple(k for k, v in tf_outputs.items() if v is not None)
        pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)

        self.assertEqual(tf_keys, pt_keys)
        check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

        # Output all for aggressive testing
        config.output_hidden_states = True
        config.output_attentions = self.has_attentions

        tf_model_class = getattr(transformers, tf_model_class_name)

        tf_model = tf_model_class(config)
        pt_model = model_class(config)

        # make sure only TF inputs that actually exist in the call signature are forwarded
        tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())

        # remove all head masks
        tf_input_keys.discard("head_mask")
        tf_input_keys.discard("cross_attn_head_mask")
        tf_input_keys.discard("decoder_head_mask")

        pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
        pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}

        # Check we can load pt model in tf and vice-versa with model => model functions
        tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
        tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
        pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

        check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)

        # Check we can load pt model in tf and vice-versa with checkpoint => model functions
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
            torch.save(pt_model.state_dict(), pt_checkpoint_path)
            tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

            tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
            tf_model.save_weights(tf_checkpoint_path)
            pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
            pt_model = pt_model.to(torch_device)

            check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
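# `check_outputs` above walks nested tuples/lists of framework tensors and compares the leaves.
# The sketch below shows the same recursion pattern on plain numpy arrays, independent of
# TensorFlow/PyTorch; the function name and the tolerance are illustrative, not part of the
# test suite above.
import numpy as np


def assert_nested_allclose(a, b, atol=1e-5):
    """Recursively compare nested tuples/lists of arrays, leaf by leaf."""
    if isinstance(a, (tuple, list)):
        assert type(a) == type(b) and len(a) == len(b)
        for a_item, b_item in zip(a, b):
            assert_nested_allclose(a_item, b_item, atol=atol)
    else:
        np.testing.assert_allclose(np.asarray(a), np.asarray(b), atol=atol)


# Example with a nested structure similar to a model output converted to a tuple.
outputs_a = (np.zeros((2, 3)), [np.ones(4), np.ones(4)])
outputs_b = (np.zeros((2, 3)), [np.ones(4), np.ones(4) + 1e-7])
assert_nested_allclose(outputs_a, outputs_b)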