def load(kmdl, path):
    # Load the parts which have identical names and shapes: std -> std; lhc-formable -> lhc-formable.
    kmdl.load_weights(path, True, True)

    file0 = file = h5py.File(path, 'r')
    if 'layer_names' not in file.attrs and 'model_weights' in file:
        file = file['model_weights']

    from tensorflow.python.keras.saving.hdf5_format import _legacy_weights, load_attributes_from_hdf5_group, \
        preprocess_weights_for_loading

    if 'keras_version' in file.attrs:
        original_keras_version = file.attrs['keras_version']  # .decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in file.attrs:
        original_backend = file.attrs['backend']  # .decode('utf8')
    else:
        original_backend = None

    layer_names = load_attributes_from_hdf5_group(file, 'layer_names')

    # Reverse index: layer name -> list of layers with that name.
    index = {}
    for layer in kmdl.layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # Load the remaining parts.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = file[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]

        layer = index.get(name, [])
        if len(layer) == 0:
            continue
        assert len(layer) == 1
        layer = layer[0]
        if type(layer) in (Conv2dLhcf, Conv2dLhcr):
            weight_values = preprocess_weights_for_loading(
                layer, weight_values, original_keras_version, original_backend)
            wdict = dict(zip(weight_names, weight_values))

            symbolic_weights = _legacy_weights(layer)
            symbol_names = [s.name for s in symbolic_weights]
            sdict = dict(zip(symbol_names, symbolic_weights))

            for pname in Conv2dLhcf.VAR_NAMES[:3]:
                # symb = [__ for _, __ in sdict.items() if _[:-2].endswith(pname)]
                symb = [__ for _, __ in sdict.items() if pname in _]
                # wght = [__ for _, __ in wdict.items() if _[:-2].endswith(pname)]
                wght = [__ for _, __ in wdict.items() if pname in _]
                assert len(symb) == 1 and len(wght) <= 1
                if len(wght) == 1:
                    weight_value_tuples.append((symb[0], wght[0]))

    KB.batch_set_value(weight_value_tuples)
    file0.close()
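# Usage sketch (illustration, not part of the original source): `build_lhc_model` and the
# checkpoint path are hypothetical stand-ins for a builder that returns a Keras model
# containing Conv2dLhcf / Conv2dLhcr layers.
def _example_load():
    model = build_lhc_model()            # hypothetical builder
    load(model, 'checkpoints/base.h5')   # hypothetical checkpoint path
    return model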
def load_tf_weights(model, resolved_archive_file):
    """
    Load the TF weights from an H5 file.

    Args:
        model (:obj:`tf.keras.models.Model`): The model to load the weights into.
        resolved_archive_file (:obj:`str`): The location of the H5 file.
    """
    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        weight_value_tuples = []

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weight_names_values = {}

                for weight_name in saved_weight_names:
                    # Strip the outermost layer scope from the saved weight name.
                    name = "/".join(weight_name.split("/")[1:])
                    saved_weight_names_values[name] = np.asarray(g[weight_name])

                for symbolic_weight in symbolic_weights:
                    splited_layers = symbolic_weight.name.split("/")[1:]
                    symbolic_weight_name = "/".join(splited_layers)

                    if symbolic_weight_name in saved_weight_names_values:
                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]

                        # Reshape the saved value if it disagrees with the symbolic weight's shape.
                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except AssertionError as e:
                                e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
                                raise e
                        else:
                            array = saved_weight_value

                        weight_value_tuples.append((symbolic_weight, array))

    K.batch_set_value(weight_value_tuples)
def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
    """
    Detect missing and unexpected layers.

    Args:
        model (:obj:`tf.keras.models.Model`): The model to load the weights into.
        resolved_archive_file (:obj:`str`): The location of the H5 file.

    Returns:
        Two lists, one for the missing layers, and another one for the unexpected layers.
    """
    missing_layers = []
    unexpected_layers = []

    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_layers = list(model_layer_names - saved_layer_names)
        unexpected_layers = list(saved_layer_names - model_layer_names)

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
                saved_weight_names_set = set(
                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names)
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                symbolic_weights_names = set(
                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights)
                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers
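# Usage sketch (illustration, not part of the original source): combines load_tf_weights and
# detect_tf_missing_unexpected_layers to restore a checkpoint and report what did not match.
# `model` and the .h5 path are hypothetical; both helpers expect an H5 file written by Keras
# with "layer_names" / "weight_names" attributes (e.g. a transformers tf_model.h5).
def _example_load_and_report(model, h5_path='tf_model.h5'):
    load_tf_weights(model, h5_path)
    missing, unexpected = detect_tf_missing_unexpected_layers(model, h5_path)
    print('missing:', missing)
    print('unexpected:', unexpected)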
def load_from_bert_pretrained(cls, config_file, pretrained_model_name='bert-base-uncased', **kwargs):
    model = cls(config_file, **kwargs)
    model(model.dummy_inputs, training=False)

    # Map this model's transformer layer indexes onto the checkpoint's layer indexes.
    ckpt_layer_mapping = {}
    for vind, ckpt_ind in enumerate(model.config.ckpt_layer_mapping.split(',')):
        ckpt_layer_mapping['layer_._{}'.format(vind)] = 'layer_._{}'.format(ckpt_ind)

    archive_file = hf_bucket_url(pretrained_model_name, filename=TF2_WEIGHTS_NAME, use_cdn=True)
    resolved_archive_file = cached_path(archive_file, cache_dir=None, force_download=False,
                                        resume_download=False, proxies=None)

    f = h5py.File(resolved_archive_file, mode='r')
    layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
    g = f[layer_names[0]]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
    weights_map = {'/'.join(name.split('/')[2:]): i for i, name in enumerate(weight_names)}

    weight_value_tuples = []
    w_names = []
    for w in model.layers[0].weights:
        w_name = '/'.join(w.name.split('/')[3:])
        for k in ckpt_layer_mapping:
            # Note: the original `if w_name.find(k):` was truthy for "not found" (-1) and falsy
            # for a match at position 0; test membership instead.
            if k in w_name:
                w_name = w_name.replace(k, ckpt_layer_mapping[k])
                break
        if w_name in weights_map and w.shape == weight_values[weights_map[w_name]].shape:
            w_names.append(w_name)
            weight_value_tuples.append((w, weight_values[weights_map[w_name]]))

    logger.info("Loaded %d weights" % (len(w_names)))
    logger.info("Loaded weights names are: %s" % (", ".join(w_names)))
    K.batch_set_value(weight_value_tuples)
    print("Loaded %d weights" % (len(w_names)))
    print("Loaded weights names are: %s" % (", ".join(w_names)))
    f.close()  # the weight values were copied by np.asarray, so the file can be closed
    model(model.dummy_inputs, training=False)
    return model
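# Usage sketch (illustration, not part of the original source): load_from_bert_pretrained
# expects config.ckpt_layer_mapping to be a comma-separated list mapping this model's
# transformer layers onto the checkpoint's layers. `MyTFBertModel` and the config file name
# are hypothetical.
def _example_from_bert_pretrained():
    # e.g. a 4-layer student reusing checkpoint layers 0, 3, 7 and 11 would set
    # ckpt_layer_mapping = "0,3,7,11" in student_config.json
    return load_from_bert_pretrained(MyTFBertModel, 'student_config.json',
                                     pretrained_model_name='bert-base-uncased')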
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come
    pre-trained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:
            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.:
              ``bert-base-uncased``.
            - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.:
              ``dbmdz/bert-base-german-cased``.
            - a path to a `directory` containing model weights saved using
              :func:`~xz_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `PyTorch state_dict saved_models file` (e.g. `./pt_model/pytorch_model.bin`). In this
              case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
              argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model
              using the provided conversion scripts and loading the TensorFlow model afterwards.
        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
        config: (`optional`) one of:
            - an instance of a class derived from :class:`~xz_transformers.PretrainedConfig`, or
            - a string valid as input to :func:`~xz_transformers.PretrainedConfig.from_pretrained()`

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be
            automatically loaded when:

            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained
              model), or
            - the model was saved using :func:`~xz_transformers.PreTrainedModel.save_pretrained` and is reloaded by
              supplying the saved_models directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
              configuration JSON file named `config.json` is found in the directory.
        from_pt: (`optional`) boolean, default False:
            Load the model weights from a PyTorch state_dict saved_models file (see docstring of
            pretrained_model_name_or_path argument).
        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model configuration should be cached if the
            standard cache should not be used.
        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if they
            exist.
        resume_download: (`optional`) boolean, default False:
            Do not delete incompletely received files. Attempt to resume the download if such a file exists.
        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.:
            {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it has been loaded) and initiate the model
            (e.g. ``output_attention=True``).
            Behaves differently depending on whether a `config` is provided or automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying
              model's ``__init__`` method (we assume all relevant updates to the configuration have already been done).
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
              initialization function (:func:`~xz_transformers.PretrainedConfig.from_pretrained`). Each key of
              ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with
              the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute
              will be passed to the underlying model's ``__init__`` function.

    Examples::

        # For example purposes. Not runnable.
        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_pt=True, config=config)

    """
    config = kwargs.pop("config", None)
    from_pt = kwargs.pop("from_pt", False)
    output_loading_info = kwargs.pop("output_loading_info", False)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            return_unused_kwargs=True,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        archive_file = None
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_pt` set to False"
                    .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            archive_file = pretrained_model_name_or_path + ".index"

        # The provided pretrained_model_name_or_path is invalid.
        if archive_file is None:
            raise EnvironmentError
        else:
            resolved_archive_file = archive_file
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        # Load from a PyTorch checkpoint
        return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

    model(model.dummy_inputs, training=False)  # build the network with dummy inputs

    assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
    # 'by_name' allows us to do transfer learning by skipping/adding layers
    # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
    try:
        model.load_weights(resolved_archive_file, by_name=True)
    except OSError:
        raise OSError(
            "Unable to load weights from h5 file. "
            "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ")

    model(model.dummy_inputs, training=False)  # Make sure restore ops are run

    # Check if the models are the same to output loading information
    with h5py.File(resolved_archive_file, "r") as f:
        if "layer_names" not in f.attrs and "model_weights" in f:
            f = f["model_weights"]
        hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))

    model_layer_names = set(layer.name for layer in model.layers)
    missing_keys = list(model_layer_names - hdf5_layer_names)
    unexpected_keys = list(hdf5_layer_names - model_layer_names)
    error_msgs = []

    if len(missing_keys) > 0:
        logger.info("Layers of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Layers from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading weights for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)))

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs,
        }
        return model, loading_info

    return model
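# Usage sketch (illustration, not part of the original source): the docstring examples above
# cover plain loading; this shows `output_loading_info=True`, which additionally returns the
# layer names missing from / unexpected in the H5 file. `MyTFModel` is a hypothetical
# subclass exposing this from_pretrained.
def _example_loading_info():
    model, info = MyTFModel.from_pretrained('bert-base-uncased', output_loading_info=True)
    print(info['missing_keys'], info['unexpected_keys'])
    return model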
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""
    Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

    The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path (:obj:`str`, `optional`):
            Can be either:

            - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
              ``bert-base-uncased``.
            - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
              ``dbmdz/bert-base-german-cased``.
            - A path to a `directory` containing model weights saved using
              :func:`~transformersTF.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
            - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In this
              case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided as
              ``config`` argument. This loading path is slower than converting the PyTorch model in a TensorFlow
              model using the provided conversion scripts and loading the TensorFlow model afterwards.
            - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
              arguments ``config`` and ``state_dict``).
        model_args (sequence of positional arguments, `optional`):
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
        config (:obj:`Union[PretrainedConfig, str]`, `optional`):
            Can be either:

            - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
            - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

            - The model is a model provided by the library (loaded with the `shortcut name` string of a pretrained
              model).
            - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded by
              supplying the save directory.
            - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
              configuration JSON file named `config.json` is found in the directory.
        from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Load the model weights from a PyTorch state_dict save file (see docstring of
            ``pretrained_model_name_or_path`` argument).
        cache_dir (:obj:`str`, `optional`):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a file
            exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g.,
            :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only look at local files (e.g., not try downloading the model).
        use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
        kwargs (remaining dictionary of keyword arguments, `optional`):
            Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
            :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
            automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying
              model's ``__init__`` method (we assume all relevant updates to the configuration have already been
              done).
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
              initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
              ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with
              the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute
              will be passed to the underlying model's ``__init__`` function.

    Examples::

        from transformers import BertConfig, TFBertModel
        # Download model and configuration from S3 and cache.
        model = TFBertModel.from_pretrained('bert-base-uncased')
        # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
        model = TFBertModel.from_pretrained('./test/saved_model/')
        # Update configuration during loading.
        model = TFBertModel.from_pretrained('bert-base-uncased', output_attention=True)
        assert model.config.output_attention == True
        # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
        config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
        model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_pt` set to False"
                    .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                use_cdn=use_cdn,
            )

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        # Load from a PyTorch checkpoint
        return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

    model(model.dummy_inputs, training=False)  # build the network with dummy inputs

    assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
    # 'by_name' allows us to do transfer learning by skipping/adding layers
    # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
    try:
        model.load_weights(resolved_archive_file, by_name=True)
    except OSError:
        raise OSError(
            "Unable to load weights from h5 file. "
            "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ")

    model(model.dummy_inputs, training=False)  # Make sure restore ops are run

    # Check if the models are the same to output loading information
    with h5py.File(resolved_archive_file, "r") as f:
        if "layer_names" not in f.attrs and "model_weights" in f:
            f = f["model_weights"]
        hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))

    model_layer_names = set(layer.name for layer in model.layers)
    missing_keys = list(model_layer_names - hdf5_layer_names)
    unexpected_keys = list(hdf5_layer_names - model_layer_names)
    error_msgs = []

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )

    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading weights for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)))

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs,
        }
        return model, loading_info

    return model
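# Usage sketch (illustration, not part of the original source): exercises the offline and
# download-related kwargs documented above. `local_files_only=True` restricts resolution to
# the local cache, and `use_cdn=False` skips Cloudfront, as suggested for very large checkpoints.
def _example_offline_load():
    from transformers import TFBertModel
    return TFBertModel.from_pretrained('bert-base-uncased', local_files_only=True, use_cdn=False)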
def from_pretrained_detailed(model_class, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come
    pre-trained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:
            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.:
              ``bert-base-uncased``.
            - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.:
              ``dbmdz/bert-base-german-cased``.
            - a path to a `directory` containing model weights saved using
              :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case,
              ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.
              This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the
              provided conversion scripts and loading the TensorFlow model afterwards.
        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
        config: (`optional`) one of:
            - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
            - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be
            automatically loaded when:

            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained
              model), or
            - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by
              supplying the save directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
              configuration JSON file named `config.json` is found in the directory.
        from_pt: (`optional`) boolean, default False:
            Load the model weights from a PyTorch state_dict save file (see docstring of
            pretrained_model_name_or_path argument).
        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model configuration should be cached if the
            standard cache should not be used.
        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if they
            exist.
        resume_download: (`optional`) boolean, default False:
            Do not delete incompletely received files. Attempt to resume the download if such a file exists.
        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.:
            {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it has been loaded) and initiate the model
            (e.g. ``output_attention=True``).
            Behaves differently depending on whether a `config` is provided or automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying
              model's ``__init__`` method (we assume all relevant updates to the configuration have already been done).
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
              initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
              ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with
              the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute
              will be passed to the underlying model's ``__init__`` function.
            - If layer pruning is supported, ``layer_pruning`` will be passed as a dictionary containing the layer
              pruning configuration (see the usage sketch after this function):

              - strategy: one of {`top`, `buttom`, `symmetric`, `alternate`, `custom`}.
              - k: the number of layers to prune. Mandatory if strategy is one of {`top`, `buttom`, `symmetric`,
                `alternate`}.
              - layers_indexes: an array of layer indexes to prune. Mandatory if strategy is `custom`.
              - is_odd: whether the alternate pruning starts from the odd layers. Mandatory if strategy is `alternate`.

    Examples::

        # For example purposes. Not runnable.
        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_pt=True, config=config)

    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_cdn = kwargs.pop("use_cdn", True)
    # mwahdan: Read layer_pruning config if it exists
    layer_pruning = kwargs.pop("layer_pruning", None)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = model_class.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_pt` set to False"
                    .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
            )

        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            if resolved_archive_file is None:
                raise EnvironmentError
        except EnvironmentError:
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # mwahdan: Modify config
    if layer_pruning:
        layer_pruning_k = layer_pruning_layers_indexes = layer_pruning_is_odd = None
        layer_pruning_strategy = get_mandatory_parameter('strategy', layer_pruning)
        if layer_pruning_strategy in {'top', 'buttom', 'symmetric'}:
            layer_pruning_k = get_mandatory_parameter('k', layer_pruning)
            config, original_num_layers = modify_num_of_layers(config, k=layer_pruning_k)
        elif layer_pruning_strategy == 'custom':
            layer_pruning_layers_indexes = get_mandatory_parameter('layers_indexes', layer_pruning)
            config, original_num_layers = modify_num_of_layers(config, layers_indexes=layer_pruning_layers_indexes)
        elif layer_pruning_strategy == 'alternate':
            layer_pruning_k = get_mandatory_parameter('k', layer_pruning)
            layer_pruning_is_odd = get_mandatory_parameter('is_odd', layer_pruning)
            config, original_num_layers = modify_num_of_layers(config, k=layer_pruning_k, is_alternate=True)
        else:
            raise Exception('`%s` is not a supported layer pruning strategy' % layer_pruning_strategy)

    # Instantiate model.
    model = model_class(config, *model_args, **model_kwargs)

    # mwahdan: Rename layers
    if layer_pruning:
        model = rename_layers_in_strategy(model, layer_pruning_strategy, original_num_layers,
                                          layer_pruning_k, layer_pruning_layers_indexes,
                                          layer_pruning_is_odd)

    if from_pt:
        # Load from a PyTorch checkpoint
        model = load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
        # mwahdan: Rename layers
        if layer_pruning is not None:
            model = rename_layers(model)
        return model

    model(model.dummy_inputs, training=False)  # build the network with dummy inputs

    assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
    # 'by_name' allows us to do transfer learning by skipping/adding layers
    # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
    try:
        # added skip_mismatch=True because we will prune full layers
        model.load_weights(resolved_archive_file, by_name=True, skip_mismatch=True)
    except OSError:
        raise OSError(
            "Unable to load weights from h5 file. "
            "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ")

    model(model.dummy_inputs, training=False)  # Make sure restore ops are run

    # mwahdan: Rename layers
    if layer_pruning is not None:
        model = rename_layers(model)

    # Check if the models are the same to output loading information
    with h5py.File(resolved_archive_file, "r") as f:
        if "layer_names" not in f.attrs and "model_weights" in f:
            f = f["model_weights"]
        hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))

    model_layer_names = set(layer.name for layer in model.layers)
    missing_keys = list(model_layer_names - hdf5_layer_names)
    unexpected_keys = list(hdf5_layer_names - model_layer_names)
    error_msgs = []

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )

    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading weights for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)))

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs,
        }
        return model, loading_info

    return model
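# Usage sketch (illustration, not part of the original source): shows the `layer_pruning`
# dictionary described in the docstring above. The strategy names and required keys follow
# the checks in from_pretrained_detailed; TFBertModel and the checkpoint name are the usual
# transformers ones.
def _example_pruned_load():
    from transformers import TFBertModel
    # prune 4 layers, alternating, starting from the odd-indexed layers (per the docstring)
    layer_pruning = {'strategy': 'alternate', 'k': 4, 'is_odd': True}
    return from_pretrained_detailed(TFBertModel, 'bert-base-uncased', layer_pruning=layer_pruning)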
def load_weights_from_hdf5_group_by_name_mapping(f, layers, name_mapping, skip_mismatch=False):
    """Implements name-based weight loading (instead of topological weight loading).

    Layers that have no matching name are skipped.

    Args:
        f: A pointer to a HDF5 group.
        layers: a list of target layers.
        name_mapping: name mapping dict.
        skip_mismatch: Boolean, whether to skip loading of layers where there is a mismatch in the number
            of weights, or a mismatch in the shape of the weights.

    Raises:
        ValueError: in case of mismatch between provided layers and weights file and skip_mismatch=False.
    """
    if 'keras_version' in f.attrs:
        original_keras_version = f.attrs['keras_version']
        if hasattr(original_keras_version, 'decode'):
            original_keras_version = original_keras_version.decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = f.attrs['backend']
        if hasattr(original_backend, 'decode'):
            original_backend = original_backend.decode('utf8')
    else:
        original_backend = None

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]

        for layer in index.get(name, []):
            symbolic_weights = _legacy_weights(layer)
            weight_values = preprocess_weights_for_loading(
                layer, weight_values, original_keras_version, original_backend)
            if len(weight_values) != len(symbolic_weights):
                if skip_mismatch:
                    logging.warning('Skipping loading of weights for '
                                    'layer {}'.format(layer.name) + ' due to mismatch '
                                    'in number of weights ({} vs {}).'.format(
                                        len(symbolic_weights), len(weight_values)))
                    continue
                raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + '") expects ' +
                                 str(len(symbolic_weights)) + ' weight(s), but the saved weights' +
                                 ' have ' + str(len(weight_values)) + ' element(s).')
            # Set values.
            for i in range(len(weight_values)):
                if backend.int_shape(symbolic_weights[i]) != weight_values[i].shape:
                    if skip_mismatch:
                        logging.warning('Skipping loading of weights for '
                                        'layer {}'.format(layer.name) + ' due to '
                                        'mismatch in shape ({} vs {}).'.format(
                                            symbolic_weights[i].shape, weight_values[i].shape))
                        continue
                    raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + '"), weight ' +
                                     str(symbolic_weights[i]) + ' has shape {}'.format(
                                         backend.int_shape(symbolic_weights[i])) +
                                     ', but the saved weight has shape ' +
                                     str(weight_values[i].shape) + '.')
                else:
                    weight_value_tuples.append((symbolic_weights[i], weight_values[i]))

    backend.batch_set_value(weight_value_tuples)
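# Usage sketch (illustration, not part of the original source): opens an H5 checkpoint and
# loads weights into a model's layers by name. Note that `name_mapping` is accepted but not
# consulted in the body above; the dict shown here is a hypothetical example of the intended
# old-name -> new-name mapping.
def _example_by_name_mapping(model, h5_path='tf_model.h5'):
    name_mapping = {'encoder/layer_._11': 'encoder/layer_._5'}  # hypothetical mapping
    with h5py.File(h5_path, 'r') as f:
        load_weights_from_hdf5_group_by_name_mapping(f, model.layers, name_mapping, skip_mismatch=True)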