def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
    """
    Build a tokenizer of type ``cls`` from a pre-trained vocabulary file.

    ``pretrained_model_name`` is first looked up in
    ``PRETRAINED_VOCAB_ARCHIVE_MAP``; an unknown name is used verbatim as a
    path or url. When the location is a directory, the file ``VOCAB_NAME``
    inside it is used. The location is then resolved through ``cached_path``
    using ``cache_dir``.

    Returns ``cls(resolved_vocab_file, *inputs, **kwargs)``, or ``None`` after
    logging an error when ``cached_path`` raises ``FileNotFoundError``.
    """
    # A registered shortcut name wins; anything else is taken literally.
    vocab_location = (
        PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
        if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP
        else pretrained_model_name
    )
    if os.path.isdir(vocab_location):
        vocab_location = os.path.join(vocab_location, VOCAB_NAME)

    # Resolve (and possibly fetch) the vocabulary through the cache.
    try:
        local_vocab = cached_path(vocab_location, cache_dir=cache_dir)
    except FileNotFoundError:
        known_names = ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name, known_names, vocab_location))
        return None

    if local_vocab != vocab_location:
        logger.info("loading vocabulary file {} from cache at {}".format(
            vocab_location, local_vocab))
    else:
        logger.info("loading vocabulary file {}".format(vocab_location))

    return cls(local_vocab, *inputs, **kwargs)
def get_config(cls, pretrained_model_name, cache_dir=None, type_vocab_size=2, state_dict=None):
    '''
    Load the ``BertConfig`` (and, if available, the weights) of a pre-trained
    archive without instantiating a model; abstracted from `from_pretrained`.

    Resolution order for the archive location:
      1. ``pretrained_model_name`` relative to the directory of this file,
         if such a path exists;
      2. the entry of ``PRETRAINED_MODEL_ARCHIVE_MAP`` for that name;
      3. ``pretrained_model_name`` itself, as a path or url.

    The location is resolved through ``cached_path``. A directory is read in
    place; anything else is treated as a gzip-compressed tar archive and
    extracted to a temporary directory that is always removed before
    returning, even when loading fails.

    Params:
        cls: unused by this function.
        pretrained_model_name: model name, path or url (see above).
        cache_dir: forwarded to ``cached_path``.
        type_vocab_size: value written to ``config.type_vocab_size``.
        state_dict: returned unchanged when not ``None``; otherwise the
            ``WEIGHTS_NAME`` file is loaded with ``torch.load`` if it exists.

    Returns:
        ``(config, state_dict)``; ``state_dict`` may still be ``None`` when no
        weights file is present. Returns ``None`` (after logging an error)
        when ``cached_path`` raises ``FileNotFoundError``.
    '''
    archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                pretrained_model_name)
    if not os.path.exists(archive_file):
        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
        else:
            archive_file = pretrained_model_name

    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
    except FileNotFoundError:
        logger.error(
            "Model name '{}' was not found in model name list. "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name, archive_file))
        return None

    if resolved_archive_file == archive_file:
        logger.info("loading archive file {}".format(archive_file))
    else:
        logger.info("loading archive file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            # NOTE(review): extractall trusts member paths inside the archive;
            # only use with archives from a trusted source.
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir

        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
    finally:
        # Clean up temp dir even if extraction or loading raised, so failed
        # loads do not leave extracted archives behind.
        if tempdir:
            shutil.rmtree(tempdir)

    return config, state_dict
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
    """
    Build a tokenizer of type ``cls`` from a pre-trained vocabulary file.

    For names registered in ``PRETRAINED_VOCAB_ARCHIVE_MAP`` the
    ``do_lower_case`` keyword is forced to agree with the model name: names
    containing ``-cased`` get ``do_lower_case=False``, all others get
    ``do_lower_case=True``; a warning is logged whenever the caller's value
    (default ``True``) is overridden. Unregistered names are used as a path or
    url, and a directory is expected to contain ``VOCAB_NAME``.

    For names registered in ``PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP``
    the ``max_len`` keyword is capped at the registered size.

    Returns ``cls(resolved_vocab_file, *inputs, **kwargs)``, or ``None`` after
    logging an error when ``cached_path`` raises ``EnvironmentError``.
    """
    if pretrained_model_name_or_path not in PRETRAINED_VOCAB_ARCHIVE_MAP:
        vocab_location = pretrained_model_name_or_path
    else:
        vocab_location = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
        model_is_cased = '-cased' in pretrained_model_name_or_path
        wants_lower_case = kwargs.get('do_lower_case', True)
        if model_is_cased and wants_lower_case:
            logger.warning(
                "The pre-trained model you are loading is a cased model but you have not set "
                "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
                "you may want to check this behavior.")
            kwargs['do_lower_case'] = False
        elif not model_is_cased and not wants_lower_case:
            logger.warning(
                "The pre-trained model you are loading is an uncased model but you have set "
                "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                "but you may want to check this behavior.")
            kwargs['do_lower_case'] = True

    if os.path.isdir(vocab_location):
        vocab_location = os.path.join(vocab_location, VOCAB_NAME)

    # Resolve (and possibly fetch) the vocabulary through the cache.
    try:
        local_vocab = cached_path(vocab_location, cache_dir=cache_dir)
    except EnvironmentError:
        known_names = ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name_or_path, known_names, vocab_location))
        return None

    if local_vocab != vocab_location:
        logger.info("loading vocabulary file {} from cache at {}".format(
            vocab_location, local_vocab))
    else:
        logger.info("loading vocabulary file {}".format(vocab_location))

    if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
        # Never let the tokenizer index sequences longer than the number of
        # positional embeddings of the pre-trained model.
        positional_limit = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[
            pretrained_model_name_or_path]
        requested_len = kwargs.get('max_len', int(1e12))
        kwargs['max_len'] = min(requested_len, positional_limit)

    return cls(local_vocab, *inputs, **kwargs)
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
    """
    Build a tokenizer of type ``cls``; the vocabulary actually loaded is the
    bundled ``chinese_wwm_ext_pytorch`` vocabulary.

    The name/path is still resolved through ``PRETRAINED_VOCAB_ARCHIVE_MAP``
    and ``cached_path`` (so a failed resolution still logs an error and
    returns ``None``), and ``max_len`` is still capped for names registered in
    ``PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP``. The resolved path is
    then replaced by ``<parent of this file's directory>/model_dir/
    chinese_wwm_ext_pytorch/vocab.txt`` before the tokenizer is constructed.

    Returns ``cls(vocab_path, *inputs, **kwargs)`` or ``None`` on a resolution
    failure.
    """
    vocab_location = (
        PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP
        else pretrained_model_name_or_path
    )
    if os.path.isdir(vocab_location):
        vocab_location = os.path.join(vocab_location, VOCAB_NAME)

    # Resolve (and possibly fetch) the vocabulary through the cache.
    try:
        local_vocab = cached_path(vocab_location, cache_dir=cache_dir)
    except EnvironmentError:
        known_names = ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name_or_path, known_names, vocab_location))
        return None

    if local_vocab != vocab_location:
        logger.info("loading vocabulary file {} from cache at {}".format(
            vocab_location, local_vocab))
    else:
        logger.info("loading vocabulary file {}".format(vocab_location))

    if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
        # Never let the tokenizer index sequences longer than the number of
        # positional embeddings of the pre-trained model.
        positional_limit = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[
            pretrained_model_name_or_path]
        requested_len = kwargs.get('max_len', int(1e12))
        kwargs['max_len'] = min(requested_len, positional_limit)

    # NOTE(review): the path resolved above is discarded here in favour of a
    # hard-coded bundled vocabulary, so the log lines above do not name the
    # file that is really loaded. Looks like a project-specific override;
    # confirm it is intentional.
    package_root = os.path.dirname(os.path.dirname(__file__))
    local_vocab = os.path.abspath(
        '%s/model_dir/%s/vocab.txt' % (package_root, 'chinese_wwm_ext_pytorch'))

    return cls(local_vocab, *inputs, **kwargs)
def _from_pretrained(cls, pretrained_model_name, cache_dir=None, *init_inputs, **kwargs):
    """
    Build a tokenizer of type ``cls`` from a pre-trained vocabulary file.

    Params:
        pretrained_model_name: a key of ``cls.pretrained_vocab_files_map``, or
            a path or url. A directory is joined with
            ``cls.vocab_file_names`` (presumably a single file name; verify
            against the class definition).
        cache_dir: forwarded to ``cached_path``.
        *init_inputs, **kwargs: forwarded to the ``cls`` constructor. For
            names found in ``cls.max_model_input_sizes``, ``max_len`` is
            capped at the registered size.

    Returns:
        ``cls(resolved_vocab_file, *init_inputs, **kwargs)``, or ``None``
        after logging an error when ``cached_path`` raises
        ``FileNotFoundError``.
    """
    # All class-level lookups go through ``cls``: this function receives no
    # instance, so there is no ``self`` in scope.
    if pretrained_model_name in cls.pretrained_vocab_files_map:
        vocab_file = cls.pretrained_vocab_files_map[pretrained_model_name]
    else:
        vocab_file = pretrained_model_name
        if os.path.isdir(vocab_file):
            vocab_file = os.path.join(vocab_file, cls.vocab_file_names)

    # redirect to the cache, if necessary
    try:
        resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
    except FileNotFoundError:
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name,
                ', '.join(cls.pretrained_vocab_files_map.keys()),
                vocab_file))
        return None

    if resolved_vocab_file == vocab_file:
        logger.info("loading vocabulary file {}".format(vocab_file))
    else:
        logger.info("loading vocabulary file {} from cache at {}".format(
            vocab_file, resolved_vocab_file))

    if pretrained_model_name in cls.max_model_input_sizes:
        # if we're using a pretrained model, ensure the tokenizer wont index
        # sequences longer than the number of positional embeddings
        max_len = cls.max_model_input_sizes[pretrained_model_name]
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

    # Instantiate tokenizer with the positional arguments this function
    # actually received (``init_inputs``).
    tokenizer = cls(resolved_vocab_file, *init_inputs, **kwargs)
    return tokenizer
def from_pretrained(cls, pretrained_model_name_or_path='gpt2-medium', state_dict=None,
                    cache_dir=None, from_tf=False, *inputs, **kwargs):
    """
    Instantiate a GPT-2 model of type ``cls`` from pre-trained files or from a
    PyTorch state dict, downloading and caching the files if needed.

    Params:
        pretrained_model_name_or_path: either
            - a key of ``PRETRAINED_MODEL_ARCHIVE_MAP`` (the weights and
              configuration locations are taken from that map and from
              ``PRETRAINED_CONFIG_ARCHIVE_MAP``), or
            - a path or url of a location containing the files named
              ``WEIGHTS_NAME`` and ``CONFIG_NAME``.
        state_dict: optional state dictionary to use instead of the weights
            file. It is not modified; renaming happens on a shallow copy.
        cache_dir: optional folder in which downloaded files are cached.
        from_tf: when true, the resolved weights file is handed to
            ``load_tf_weights_in_gpt2`` and its result is returned directly.
        *inputs, **kwargs: additional arguments for the ``cls`` constructor.

    Returns:
        The loaded model, or ``None`` (after logging an error) when
        ``cached_path`` raises ``EnvironmentError``.

    Raises:
        RuntimeError: when loading the state dict reports error messages.
    """
    if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
        archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
        config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
        archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
        config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
    except EnvironmentError:
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find files {} and {} "
            "at this path or url.".format(
                pretrained_model_name_or_path,
                ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                pretrained_model_name_or_path, archive_file, config_file))
        return None

    if resolved_archive_file == archive_file and resolved_config_file == config_file:
        logger.info("loading weights file {}".format(archive_file))
        logger.info("loading configuration file {}".format(config_file))
    else:
        logger.info("loading weights file {} from cache at {}".format(
            archive_file, resolved_archive_file))
        logger.info("loading configuration file {} from cache at {}".format(
            config_file, resolved_config_file))

    # Load config
    config = GPT2Config.from_json_file(resolved_config_file)
    logger.info("Model config {}".format(config))

    # Instantiate model.
    model = cls(config, *inputs, **kwargs)
    if state_dict is None and not from_tf:
        state_dict = torch.load(resolved_archive_file, map_location='cpu')
    if from_tf:
        # Directly load from a TensorFlow checkpoint (stored as NumPy array)
        return load_tf_weights_in_gpt2(model, resolved_archive_file)

    # Work on a shallow copy: keys are renamed below and
    # _load_from_state_dict may modify the mapping, and neither should leak
    # into a state_dict supplied by the caller.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Map the short parameter suffixes to the PyTorch parameter names.
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if key.endswith(".g"):
            new_key = key[:-2] + ".weight"
        elif key.endswith(".b"):
            new_key = key[:-2] + ".bias"
        elif key.endswith(".w"):
            new_key = key[:-2] + ".weight"
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=""):
        # Apply _load_from_state_dict to the module and, recursively, to each
        # of its children under the matching dotted prefix.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # When no key carries the 'transformer.' prefix, load into the
    # ``transformer`` sub-module instead of the wrapping model.
    start_model = model
    if hasattr(model, "transformer") and all(
            not s.startswith('transformer.') for s in state_dict.keys()):
        start_model = model.transformer
    load(start_model, prefix="")

    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
            model.__class__.__name__, "\n\t".join(error_msgs)))

    # Make sure we are still sharing the output and input embeddings after loading weights
    model.set_tied()
    return model
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
    r"""
    Load a configuration object of type ``cls`` from a pre-trained
    configuration file.

    Params:
        **pretrained_model_name_or_path**: either:
            - a key of ``cls.pretrained_config_archive_map`` (a `shortcut name`);
            - a path to a `directory`, in which the file ``CONFIG_NAME`` is read;
            - a path or url to a configuration `file`.
        **cache_dir**: (`optional`) string:
            Taken out of ``kwargs`` and forwarded to ``cached_path``.
        **kwargs**: (`optional`) dict:
            Every key that already exists as an attribute on the loaded
            configuration overrides that attribute; other keys are ignored.

    Returns the configuration, or ``None`` after logging an error when
    ``cached_path`` raises ``EnvironmentError``.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> config = BertConfig.from_pretrained('./test/saved_model/')
        >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
        >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
        >>> assert config.output_attention == True
    """
    cache_dir = kwargs.pop('cache_dir', None)
    archive_map = cls.pretrained_config_archive_map

    if pretrained_model_name_or_path in archive_map:
        config_location = archive_map[pretrained_model_name_or_path]
    elif os.path.isdir(pretrained_model_name_or_path):
        config_location = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
    else:
        config_location = pretrained_model_name_or_path

    # Resolve (and possibly fetch) the configuration through the cache.
    try:
        local_config = cached_path(config_location, cache_dir=cache_dir)
    except EnvironmentError:
        if pretrained_model_name_or_path not in cls.pretrained_config_archive_map:
            message = (
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(cls.pretrained_config_archive_map.keys()),
                    config_location))
        else:
            message = (
                "Couldn't reach server at '{}' to download pretrained model configuration file."
                .format(config_location))
        logger.error(message)
        return None

    if local_config != config_location:
        logger.info("loading configuration file {} from cache at {}".format(
            config_location, local_config))
    else:
        logger.info("loading configuration file {}".format(config_location))

    loaded = cls.from_json_file(local_config)

    # Apply overrides: each keyword matching an existing attribute is set on
    # the configuration and dropped from the (local) kwargs.
    for option in list(kwargs):
        if hasattr(loaded, option):
            setattr(loaded, option, kwargs.pop(option))

    logger.info("Model config %s", loaded)
    return loaded
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
    r"""Instantiate a pretrained pytorch model of type ``cls``.

    The model is put in evaluation mode with ``model.eval()`` before it is
    returned; call ``model.train()`` to train it.

    Params:
        **pretrained_model_name_or_path**: either:
            - a key of ``cls.pretrained_model_archive_map`` (a `shortcut name`);
            - a path to a `directory`; the file ``WEIGHTS_NAME`` is read, or
              ``TF_WEIGHTS_NAME + ".index"`` when ``from_tf`` is true;
            - a path or url to a weights `file`; when ``from_tf`` is true,
              ``".index"`` is appended to it.
        **config**: optional configuration to use instead of the one loaded
            with ``cls.config_class.from_pretrained``.
        **state_dict**: optional state dictionary to use instead of the
            weights file. It is not modified; renaming happens on a shallow copy.
        **cache_dir**: (`optional`) string:
            Cache directory, used for the weights and, when given, also
            forwarded to the configuration loader.
        **from_tf**: (`optional`) boolean:
            Load through ``cls.load_tf_weights`` and return its result.
        **output_loading_info**: (`optional`) boolean:
            Set to ``True`` to also return a dictionary with the missing
            keys, unexpected keys and error messages.
        **kwargs**: (`optional`) dict:
            Remaining keywords are passed to the configuration loader.

    Returns:
        The model, or ``(model, loading_info)`` when ``output_loading_info``
        is true, or ``None`` (after logging an error) when ``cached_path``
        raises ``EnvironmentError``.

    Raises:
        RuntimeError: when loading the state dict reports error messages.

    Examples::

        >>> model = BertModel.from_pretrained('bert-base-uncased')
        >>> model = BertModel.from_pretrained('./test/saved_model/')
        >>> model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
        >>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        >>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    config = kwargs.pop('config', None)
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    output_loading_info = kwargs.pop('output_loading_info', False)

    # Load config
    if config is None:
        # ``cache_dir`` was popped above; hand it back to the configuration
        # loader so the configuration honours the same cache directory as the
        # weights. It is only forwarded when the caller supplied one.
        config_kwargs = dict(kwargs)
        if cache_dir is not None:
            config_kwargs['cache_dir'] = cache_dir
        config = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, *inputs, **config_kwargs)

    # Load model
    if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
        archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
    elif os.path.isdir(pretrained_model_name_or_path):
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            archive_file = os.path.join(pretrained_model_name_or_path,
                                        TF_WEIGHTS_NAME + ".index")
        else:
            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
    else:
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            archive_file = pretrained_model_name_or_path

    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
    except EnvironmentError:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            logger.error(
                "Couldn't reach server at '{}' to download pretrained weights."
                .format(archive_file))
        else:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(cls.pretrained_model_archive_map.keys()),
                    archive_file))
        return None

    if resolved_archive_file == archive_file:
        logger.info("loading weights file {}".format(archive_file))
    else:
        logger.info("loading weights file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    # Instantiate model.
    model = cls(config)

    if state_dict is None and not from_tf:
        state_dict = torch.load(resolved_archive_file, map_location='cpu')
    if from_tf:
        # Directly load from a TensorFlow checkpoint
        return cls.load_tf_weights(
            model, config, resolved_archive_file[:-6])  # Remove the '.index'

    # Work on a shallow copy: keys are renamed below and
    # _load_from_state_dict may modify the mapping, and neither should leak
    # into a state_dict supplied by the caller.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # Load from a PyTorch state_dict
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=''):
        # Apply _load_from_state_dict to the module and, recursively, to each
        # of its children under the matching dotted prefix.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = ''
    model_to_load = model
    if not hasattr(model, cls.base_model_prefix) and any(
            s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
        start_prefix = cls.base_model_prefix + '.'
    if hasattr(model, cls.base_model_prefix) and not any(
            s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
        model_to_load = getattr(model, cls.base_model_prefix)

    load(model_to_load, prefix=start_prefix)
    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            model.__class__.__name__, "\n\t".join(error_msgs)))

    if hasattr(model, 'tie_weights'):
        model.tie_weights()  # make sure word embedding weights are still tied

    # Set model in evaluation mode to desactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs
        }
        return model, loading_info

    return model
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
                    from_tf=False, *inputs, **kwargs):
    """
    Instantiate a BERT model of type ``cls`` from a pre-trained archive or
    from a PyTorch state dict, downloading and caching the archive if needed.

    Params:
        pretrained_model_name_or_path: either
            - a key of ``PRETRAINED_MODEL_ARCHIVE_MAP``, or
            - a path or url to a model location. A directory (or any location
              when ``from_tf`` is true) is read in place; anything else is
              treated as a gzip-compressed tar archive and extracted to a
              temporary directory. The location must provide the file
              ``CONFIG_NAME`` and either ``WEIGHTS_NAME`` or, with
              ``from_tf``, ``TF_WEIGHTS_NAME``.
        state_dict: optional state dictionary to use instead of the weights
            file. It is not modified; renaming happens on a shallow copy.
        cache_dir: optional folder in which downloaded archives are cached.
        from_tf: when true, weights are loaded with
            ``load_tf_weights_in_bert`` and its result is returned directly.
        *inputs, **kwargs: additional arguments for the ``cls`` constructor.

    Returns:
        The loaded model, or ``None`` (after logging an error) when
        ``cached_path`` raises ``EnvironmentError``.

    Raises:
        RuntimeError: when loading the state dict reports error messages.
    """
    if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
        archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
        archive_file = pretrained_model_name_or_path

    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
    except EnvironmentError:
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name_or_path,
                ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                archive_file))
        return None

    if resolved_archive_file == archive_file:
        logger.info("loading archive file {}".format(archive_file))
    else:
        logger.info("loading archive file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file) or from_tf:
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            # NOTE(review): extractall trusts member paths inside the archive;
            # only use with archives from a trusted source.
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir

        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(
                weights_path,
                map_location='cpu' if not torch.cuda.is_available() else None)
    finally:
        # Clean up temp dir even if the config, the model constructor or the
        # weights failed to load, so failed loads leave nothing behind.
        if tempdir:
            shutil.rmtree(tempdir)

    if from_tf:
        # Directly load from a TensorFlow checkpoint
        weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
        return load_tf_weights_in_bert(model, weights_path)

    # Work on a shallow copy: keys are renamed below and
    # _load_from_state_dict may modify the mapping, and neither should leak
    # into a state_dict supplied by the caller.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Load from a PyTorch state_dict: rename 'gamma'/'beta' parameters to
    # 'weight'/'bias'.
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=''):
        # Apply _load_from_state_dict to the module and, recursively, to each
        # of its children under the matching dotted prefix.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # A model without a ``bert`` attribute loads 'bert.'-prefixed checkpoint
    # keys by starting the traversal with that prefix.
    start_prefix = ''
    if not hasattr(model, 'bert') and any(
            s.startswith('bert.') for s in state_dict.keys()):
        start_prefix = 'bert.'
    load(model, prefix=start_prefix)

    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            model.__class__.__name__, "\n\t".join(error_msgs)))
    return model
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
    """
    Instantiate a tokenizer of type ``cls`` from pre-trained vocabulary and
    merges files, downloading and caching them if needed.

    Params:
        pretrained_model_name_or_path: either a key of
            ``PRETRAINED_VOCAB_ARCHIVE_MAP`` (the merges file then comes from
            ``PRETRAINED_MERGES_ARCHIVE_MAP``), or a path or url of a location
            containing ``VOCAB_NAME`` and ``MERGES_NAME`` and, optionally,
            ``SPECIAL_TOKENS_NAME``.
        cache_dir: forwarded to ``cached_path``.
        *inputs, **kwargs: forwarded to the ``cls`` constructor. ``max_len``
            is capped for names registered in
            ``PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP``. An explicit
            ``special_tokens`` keyword takes precedence over the special
            tokens file; without either, an empty list is used.

    Returns:
        The tokenizer, or ``None`` (after logging an error) when
        ``cached_path`` raises ``EnvironmentError``.
    """
    if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
        vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
        merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
        special_tokens_file = None
    else:
        vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
        merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
        special_tokens_file = os.path.join(pretrained_model_name_or_path,
                                           SPECIAL_TOKENS_NAME)
        if not os.path.exists(special_tokens_file):
            special_tokens_file = None
        else:
            logger.info("loading special tokens file {}".format(special_tokens_file))

    # redirect to the cache, if necessary
    try:
        resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
        resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
    except EnvironmentError:
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            logger.error(
                "Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
        else:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path, vocab_file, merges_file))
        return None

    if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
        logger.info("loading vocabulary file {}".format(vocab_file))
        logger.info("loading merges file {}".format(merges_file))
    else:
        logger.info("loading vocabulary file {} from cache at {}".format(
            vocab_file, resolved_vocab_file))
        logger.info("loading merges file {} from cache at {}".format(
            merges_file, resolved_merges_file))

    if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
        # if we're using a pretrained model, ensure the tokenizer wont index
        # sequences longer than the number of positional embeddings
        max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[
            pretrained_model_name_or_path]
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

    # Instantiate tokenizer.
    if special_tokens_file and 'special_tokens' not in kwargs:
        # One token per line; the slice drops the piece after the final
        # newline. The file is closed as soon as it has been read.
        with open(special_tokens_file, encoding='utf-8') as handle:
            special_tokens = handle.read().split('\n')[:-1]
    else:
        special_tokens = kwargs.pop('special_tokens', [])
    tokenizer = cls(resolved_vocab_file, resolved_merges_file,
                    special_tokens=special_tokens, *inputs, **kwargs)
    return tokenizer
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    """Instantiate a model and fill it with weights from a PyTorch state dict.

    The keyword arguments ``config``, ``state_dict``, ``cache_dir``,
    ``from_tf``, ``force_download``, ``proxies`` and ``output_loading_info``
    are consumed here. When no ``config`` is given, the remaining keyword
    arguments go through ``cls.config_class.from_pretrained`` and whatever it
    returns as unused is passed to the model constructor; otherwise they are
    passed to the constructor directly.

    Parameters:
        pretrained_model_name_or_path: a key of
            ``cls.pretrained_model_archive_map``, a directory holding a weights
            file, a weights file, or ``None`` (no weights file is resolved).
        model_args: positional arguments for the configuration loader and the
            model constructor.

    Returns:
        The model in evaluation mode, or ``(model, loading_info)`` when
        ``output_loading_info`` is true.

    Raises:
        EnvironmentError: when no weights file can be located or fetched.
        RuntimeError: when loading the state dict reports errors.
    """
    config = kwargs.pop('config', None)
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)
    output_loading_info = kwargs.pop('output_loading_info', False)

    # Configuration: use the one supplied, otherwise load it.
    if config is not None:
        model_kwargs = kwargs
    else:
        config, model_kwargs = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, *model_args,
            cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download,
            proxies=proxies,
            **kwargs
        )

    # Weights file: map the name/path to a concrete file, then go through the cache.
    resolved_archive_file = None
    if pretrained_model_name_or_path is not None:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            tf1_index = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            tf2_weights = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            pt_weights = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            if from_tf and os.path.isfile(tf1_index):
                # TF 1.0 checkpoint
                archive_file = tf1_index
            elif from_tf and os.path.isfile(tf2_weights):
                # TF 2.0 checkpoint
                archive_file = tf2_weights
            elif os.path.isfile(pt_weights):
                # PyTorch checkpoint
                archive_file = pt_weights
            else:
                raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
                    [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                    pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
            archive_file = pretrained_model_name_or_path + ".index"

        try:
            resolved_archive_file = cached_path(
                archive_file, cache_dir=cache_dir,
                force_download=force_download, proxies=proxies)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                    archive_file)
            else:
                msg = "Model name '{}' was not found in model name list ({}). " \
                      "We assumed '{}' was a path or url to model weight files named one of {} but " \
                      "couldn't find any such file at this path or url.".format(
                          pretrained_model_name_or_path,
                          ', '.join(cls.pretrained_model_archive_map.keys()),
                          archive_file,
                          [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
            raise EnvironmentError(msg)

        if resolved_archive_file != archive_file:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        else:
            logger.info("loading weights file {}".format(archive_file))

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        state_dict = torch.load(resolved_archive_file, map_location='cpu')

    missing_keys, unexpected_keys, error_msgs = [], [], []

    # Nothing is loaded when `from_tf` is set (not required for now).
    if not from_tf:
        # Old checkpoints name some parameters 'gamma'/'beta'; rename them to
        # 'weight'/'bias'. Renames are collected first so the dict is not
        # modified while it is being iterated.
        renames = []
        for key in state_dict.keys():
            converted = None
            if 'gamma' in key:
                converted = key.replace('gamma', 'weight')
            if 'beta' in key:
                converted = key.replace('beta', 'bias')
            if converted:
                renames.append((key, converted))
        for stale_key, fresh_key in renames:
            state_dict[fresh_key] = state_dict.pop(stale_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
        # so we need to apply the function recursively.
        def load(module, prefix=''):
            if metadata is None:
                local_metadata = {}
            else:
                local_metadata = metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True,
                missing_keys, unexpected_keys, error_msgs)
            for child_name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + child_name + '.')

        # Make sure we are able to load base models as well as derived models (with heads)
        base_prefix = cls.base_model_prefix
        model_has_base = hasattr(model, base_prefix)
        checkpoint_has_base = any(s.startswith(base_prefix) for s in state_dict.keys())
        start_prefix = ''
        model_to_load = model
        if not model_has_base and checkpoint_has_base:
            start_prefix = base_prefix + '.'
        elif model_has_base and not checkpoint_has_base:
            model_to_load = getattr(model, base_prefix)

        load(model_to_load, prefix=start_prefix)

        model_name = model.__class__.__name__
        if missing_keys:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model_name, missing_keys))
        if unexpected_keys:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model_name, unexpected_keys))
        if error_msgs:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model_name, "\n\t".join(error_msgs)))

    if hasattr(model, 'tie_weights'):
        model.tie_weights()  # make sure word embedding weights are still tied

    # Evaluation mode by default, which deactivates DropOut modules
    model.eval()

    if not output_loading_info:
        return model

    loading_info = {"missing_keys": missing_keys,
                    "unexpected_keys": unexpected_keys,
                    "error_msgs": error_msgs}
    return model, loading_info
def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs):
    """
    Instantiate a PreTrainedBertModel from a pre-trained model file.
    Download and cache the pre-trained model file if needed.

    Params:
        pretrained_model_name: either:
            - a str with the name of a pre-trained model to load selected in the list of:
                . `bert-base-uncased`
                . `bert-large-uncased`
                . `bert-base-cased`
                . `bert-base-multilingual`
                . `bert-base-chinese`
            - a path or url to a pretrained model archive containing:
                . `bert_config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
        *inputs, **kwargs: additional input for the specific Bert class
            (ex: num_labels for BertForSequenceClassification)

    Returns:
        The model with the pre-trained weights loaded, or `None` (after
        logging an error) when the archive cannot be found.
    """
    if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
        archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
    else:
        archive_file = pretrained_model_name

    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file)
    except FileNotFoundError:
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name,
                ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                pretrained_model_name))
        return None

    if resolved_archive_file == archive_file:
        logger.info("loading archive file {}".format(archive_file))
    else:
        logger.info("loading archive file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            # NOTE(security): extractall() trusts member paths inside the
            # archive; only use archives from trusted sources.
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir

        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    finally:
        # Clean up temp dir, even when reading the config or the weights fails,
        # so a failed load does not leave the extracted archive behind.
        if tempdir:
            shutil.rmtree(tempdir)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Load this module's own entries, then recurse into its children.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # A model without a `bert` attribute is loaded from the 'bert.'-prefixed keys.
    load(model, prefix='' if hasattr(model, 'bert') else 'bert.')

    if missing_keys:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if unexpected_keys:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    return model
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Resolve ``pretrained_model_name_or_path`` to a dictionary of parameters
    that can be used to instantiate a :class:`~transformers.PretrainedConfig`
    with ``from_dict``.

    The keyword arguments ``cache_dir``, ``force_download``,
    ``resume_download``, ``proxies`` and ``local_files_only`` are consumed and
    forwarded to ``cached_path``; everything else is returned untouched.

    Parameters:
        pretrained_model_name_or_path (:obj:`str`):
            The identifier of the pre-trained checkpoint from which we want
            the dictionary of parameters: a directory, a file, a remote url,
            or a name passed to ``hf_bucket_url``.

    Returns:
        :obj:`Tuple[Dict, Dict]`: The configuration dictionary and the
        keyword arguments that were not consumed here.

    Raises:
        EnvironmentError: if the configuration file cannot be resolved, or if
            it is not valid JSON.
    """
    # Download-related options, in the order they are handed to `cached_path`.
    fetch_options = {
        "cache_dir": kwargs.pop("cache_dir", None),
        "force_download": kwargs.pop("force_download", False),
        "proxies": kwargs.pop("proxies", None),
        "resume_download": kwargs.pop("resume_download", False),
        "local_files_only": kwargs.pop("local_files_only", False),
    }

    location = pretrained_model_name_or_path
    if os.path.isdir(location):
        config_file = os.path.join(location, CONFIG_NAME)
    elif os.path.isfile(location) or is_remote_url(location):
        config_file = location
    else:
        config_file = hf_bucket_url(location, filename=CONFIG_NAME, use_cdn=False)

    try:
        # Load from URL or cache if already cached
        resolved_config_file = cached_path(config_file, **fetch_options)
        # Load config dict
        if resolved_config_file is None:
            raise EnvironmentError
        config_dict = cls._dict_from_json_file(resolved_config_file)
    except EnvironmentError:
        msg = (
            f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
            f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
            f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
        )
        raise EnvironmentError(msg)
    except json.JSONDecodeError:
        msg = (
            "Couldn't reach server at '{}' to download configuration file or "
            "configuration file is not a valid JSON file. "
            "Please check network or file content here: {}.".format(
                config_file, resolved_config_file))
        raise EnvironmentError(msg)

    if resolved_config_file != config_file:
        notice = "loading configuration file {} from cache at {}".format(
            config_file, resolved_config_file)
    else:
        notice = "loading configuration file {}".format(config_file)
    logger.info(notice)

    return config_dict, kwargs
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come
    pre-trained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:

            - a string with the `shortcut name` of a pre-trained model to load from cache or download,
              e.g.: ``bert-base-uncased``.
            - a path to a `directory` containing model weights saved using
              :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this
              case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
              argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model
              using the provided conversion scripts and loading the TensorFlow model afterwards.

        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method

        config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a
              pretrained model), or
            - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by
              supplying the save directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
              configuration JSON file named `config.json` is found in the directory.

        from_pt: (`optional`) boolean, default False:
            Load the model weights from a PyTorch state_dict save file (see docstring of
            pretrained_model_name_or_path argument).

        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model
            configuration should be cached if the standard cache should not be used.

        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if
            they exist.

        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.:
            {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request, including the configuration download.

        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it being loaded) and initiate the model.
            (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or
            automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
              underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
              already been done)
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
              initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
              ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
              with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
              attribute will be passed to the underlying model's ``__init__`` function.

    Raises:
        EnvironmentError: when no weights file can be located or fetched.

    Examples::

        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
        config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
        model = BertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)

    """
    config = kwargs.pop('config', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_pt = kwargs.pop('from_pt', False)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)

    # Load config
    if config is None:
        # `proxies` must be forwarded here as well, otherwise the
        # configuration download ignores the proxy used for the weights.
        config, model_kwargs = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, *model_args,
            cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download,
            proxies=proxies,
            **kwargs
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_pt` set to False"
                    .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            raise EnvironmentError("Error file {} not found".format(
                pretrained_model_name_or_path))

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file, cache_dir=cache_dir,
                force_download=force_download, proxies=proxies)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained weights."
                    .format(archive_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_model_archive_map.keys()),
                        archive_file))
            # Re-raise the original error after logging it.
            raise

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        # Load from a PyTorch checkpoint
        return cls.load_pt_weights(model, resolved_archive_file)

    inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    model(inputs, training=False)  # build the network with dummy inputs

    assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(
        resolved_archive_file)
    # 'by_name' allow us to do transfer learning by skipping/adding layers
    # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
    model.load_weights(resolved_archive_file, by_name=True)

    model(inputs, training=False)  # Make sure restore ops are run

    return model
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

    The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
    To train the model, you should first set it back in training mode with ``model.train()``

    The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come
    pre-trained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

    The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:

            - a string with the `shortcut name` of a pre-trained model to load from cache or download,
              e.g.: ``bert-base-uncased``.
            - a path to a `directory` containing model weights saved using
              :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this
              case, ``from_tf`` should be set to True and a configuration object should be provided as ``config``
              argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model
              using the provided conversion scripts and loading the PyTorch model afterwards.
            - None if you are both providing the configuration and state dictionary (resp. with keyword arguments
              ``config`` and ``state_dict``)

        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method

        config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a
              pretrained model), or
            - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by
              supplying the save directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
              configuration JSON file named `config.json` is found in the directory.

        state_dict: (`optional`) dict:
            an optional state dictionary for the model to use instead of a state dictionary loaded from saved
            weights file. This option can be used if you want to create a model from a pretrained configuration
            but load your own weights. In this case though, you should check if using
            :func:`~transformers.PreTrainedModel.save_pretrained` and
            :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model
            configuration should be cached if the standard cache should not be used.

        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if
            they exist.

        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.:
            {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request, including the configuration download.

        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error
            messages.

        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it being loaded) and initiate the model.
            (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or
            automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
              underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
              already been done)
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
              initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
              ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
              with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
              attribute will be passed to the underlying model's ``__init__`` function.

    Returns:
        The model in evaluation mode, or ``(model, loading_info)`` when ``output_loading_info`` is true.

    Raises:
        EnvironmentError: when no weights file can be located or fetched.
        ImportError: when a TF 2.0 checkpoint is requested but the conversion helper cannot be imported.
        RuntimeError: when loading the state dict reports errors.

    Examples::

        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
        model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

    """
    config = kwargs.pop('config', None)
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)
    output_loading_info = kwargs.pop('output_loading_info', False)

    # Load config
    if config is None:
        # `proxies` must be forwarded here as well, otherwise the
        # configuration download ignores the proxy used for the weights.
        config, model_kwargs = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, *model_args,
            cache_dir=cache_dir, return_unused_kwargs=True,
            force_download=force_download,
            proxies=proxies,
            **kwargs
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                # Load from a TF 1.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
                    [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                    pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
            archive_file = pretrained_model_name_or_path + ".index"

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir,
                                                force_download=force_download, proxies=proxies)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                    archive_file)
            else:
                msg = "Model name '{}' was not found in model name list ({}). " \
                      "We assumed '{}' was a path or url to model weight files named one of {} but " \
                      "couldn't find any such file at this path or url.".format(
                          pretrained_model_name_or_path,
                          ', '.join(cls.pretrained_model_archive_map.keys()),
                          archive_file,
                          [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None and not from_tf:
        state_dict = torch.load(resolved_archive_file, map_location='cpu')

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    if from_tf:
        if resolved_archive_file.endswith('.index'):
            # Load from a TensorFlow 1.X checkpoint - provided by original authors
            model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
        else:
            # Load from our TensorFlow 2.0 checkpoints
            try:
                from transformers import load_tf2_checkpoint_in_pytorch_model
                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
            except ImportError:
                logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
                raise
    else:
        # Convert old format to new format if needed from a PyTorch state_dict.
        # Renames are collected first so the dict is not modified while iterating.
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Load this module's own entries, then recurse into its children.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            start_prefix = cls.base_model_prefix + '.'
        if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            model_to_load = getattr(model, cls.base_model_prefix)

        load(model_to_load, prefix=start_prefix)
        if missing_keys:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if unexpected_keys:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if error_msgs:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

    if hasattr(model, 'tie_weights'):
        model.tie_weights()  # make sure word embedding weights are still tied

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
        return model, loading_info

    return model
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    r""" Build a configuration object (an instance of ``cls``) from a pre-trained model configuration.

    Parameters:
        pretrained_model_name_or_path: one of:

            - a `shortcut name`, i.e. a key of ``cls.pretrained_config_archive_map``,
              e.g.: ``bert-base-uncased``.
            - a path to a `directory`; the file named ``CONFIG_NAME`` inside it is loaded,
              e.g.: ``./my_model_directory/``.
            - a path or url to a configuration JSON `file`,
              e.g.: ``./my_model_directory/configuration.json``.

        cache_dir: (`optional`) string:
            Forwarded to ``cached_path``: directory in which a downloaded configuration should be cached
            when the standard cache should not be used.

        kwargs: (`optional`) dict: key/value pairs used to update the configuration object after loading.

            - Every key that is already an attribute of the loaded configuration overrides that attribute.
            - Keys that are *not* configuration attributes are left untouched; whether they are returned
              is controlled by ``return_unused_kwargs``.

        force_download: (`optional`) boolean, default False:
            Forwarded to ``cached_path``: force a (re-)download even if a cached version exists.

        proxies: (`optional`) dict, default None:
            Forwarded to ``cached_path``: proxy servers by protocol or endpoint,
            e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.

        return_unused_kwargs: (`optional`) bool, default False:

            - If False, only the configuration object is returned.
            - If True, a tuple `(config, unused_kwargs)` is returned, where `unused_kwargs` holds the
              key/value pairs of ``kwargs`` that did not match any configuration attribute.

    Raises:
        EnvironmentError: re-raised (after logging) when ``cached_path`` cannot resolve the file.

    Examples::

        # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
        # derived class: BertConfig
        config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
        config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
        config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
        config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
        assert config.output_attention == True
        config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
                                                           foo=False, return_unused_kwargs=True)
        assert config.output_attention == True
        assert unused_kwargs == {'foo': False}

    """
    # Loader options travel in kwargs; strip them before kwargs is used for overrides.
    cache_dir = kwargs.pop('cache_dir', None)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)
    return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

    # Work out which file (or url) holds the configuration.
    name_or_path = pretrained_model_name_or_path
    if name_or_path in cls.pretrained_config_archive_map:
        config_file = cls.pretrained_config_archive_map[name_or_path]
    else:
        config_file = (
            os.path.join(name_or_path, CONFIG_NAME)
            if os.path.isdir(name_or_path)
            else name_or_path
        )

    # redirect to the cache, if necessary
    try:
        resolved_config_file = cached_path(
            config_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies)
    except EnvironmentError as e:
        if name_or_path not in cls.pretrained_config_archive_map:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    name_or_path,
                    ', '.join(cls.pretrained_config_archive_map.keys()),
                    config_file))
        else:
            logger.error(
                "Couldn't reach server at '{}' to download pretrained model configuration file."
                .format(config_file))
        raise e

    if resolved_config_file != config_file:
        logger.info(
            "loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
    else:
        logger.info("loading configuration file {}".format(config_file))

    # Load config
    config = cls.from_json_file(resolved_config_file)

    # Rebuild pruned_heads with int keys and set values.
    if hasattr(config, 'pruned_heads'):
        config.pruned_heads = {
            int(layer): set(heads)
            for layer, heads in config.pruned_heads.items()
        }

    # Apply every kwarg that names an existing attribute, then drop it from kwargs.
    consumed = []
    for name in kwargs:
        if not hasattr(config, name):
            continue
        setattr(config, name, kwargs[name])
        consumed.append(name)
    for name in consumed:
        del kwargs[name]

    logger.info("Model config %s", config)
    return (config, kwargs) if return_unused_kwargs else config
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

    The network is built by calling the model once on ``model.dummy_inputs`` with ``training=False``
    before the weights are loaded.

    The log message ``Layers of XXX not initialized from pretrained model`` lists the layers of model XXX
    that are absent from the weights file. It is up to you to train those layers with a downstream
    fine-tuning task.

    The log message ``Layers from pretrained model not used in XXX`` lists the layers stored in the weights
    file that model XXX does not have; those weights are discarded.

    Parameters:
        pretrained_model_name_or_path: either:

            - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
            - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
            - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
            - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
            - ``None``: no weights file is resolved.

        model_args: (`optional`) Sequence of positional arguments:
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method

        config: (`optional`) one of:
                - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

            - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
            - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
            - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

        from_pt: (`optional`) boolean, default False:
            Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

        cache_dir: (`optional`) string:
            Path to a directory in which a downloaded pre-trained model
            configuration should be cached if the standard cache should not be used.

        force_download: (`optional`) boolean, default False:
            Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

        resume_download: (`optional`) boolean, default False:
            Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

        proxies: (`optional`) dict, default None:
            A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
            The proxies are used on each request, for the configuration as well as for the weights.

        output_loading_info: (`optional`) boolean:
            Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            Not honoured on the ``from_pt`` path, which returns the result of the PyTorch-to-TF loader directly.

        kwargs: (`optional`) Remaining dictionary of keyword arguments:
            Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

            - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
            - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

    Raises:
        EnvironmentError: when no weights file can be located or downloaded, or when ``from_pt`` is
            combined with a model identifier name.
        OSError: when the resolved h5 file cannot be loaded.
        RuntimeError: when error messages were collected while loading the weights.

    Examples::

        # For example purposes. Not runnable.
        model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
        model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
        assert model.config.output_attention == True
        # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
        config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
        model = BertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)

    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            # Forward the proxy settings too: they were popped from kwargs above, so
            # without this the configuration download would silently bypass them.
            proxies=proxies,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    "Error no file named {} found in directory {} or `from_pt` set to False"
                    .format([WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path))
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path + ".index"):
            archive_file = pretrained_model_name_or_path + ".index"
        else:
            # Anything else is treated as a model identifier name.
            archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
            if from_pt:
                raise EnvironmentError(
                    "Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name."
                )

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
            )
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained weights."
                    .format(archive_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ", ".join(cls.pretrained_model_archive_map.keys()),
                        archive_file,
                    ))
            raise
        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
    else:
        resolved_archive_file = None

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        # Load from a PyTorch checkpoint
        return load_pytorch_checkpoint_in_tf2_model(
            model, resolved_archive_file, allow_missing_keys=True)

    model(model.dummy_inputs, training=False)  # build the network with dummy inputs

    assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(
        resolved_archive_file)
    # 'by_name' allow us to do transfer learning by skipping/adding layers
    # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
    try:
        model.load_weights(resolved_archive_file, by_name=True)
    except OSError:
        raise OSError(
            "Unable to load weights from h5 file. "
            "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
        )

    model(model.dummy_inputs, training=False)  # Make sure restore ops are run

    # Check if the models are the same to output loading informations
    with h5py.File(resolved_archive_file, "r") as f:
        if "layer_names" not in f.attrs and "model_weights" in f:
            f = f["model_weights"]
        hdf5_layer_names = set(
            hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
    model_layer_names = set(layer.name for layer in model.layers)
    missing_keys = list(model_layer_names - hdf5_layer_names)
    unexpected_keys = list(hdf5_layer_names - model_layer_names)
    error_msgs = []

    if len(missing_keys) > 0:
        logger.info(
            "Layers of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info(
            "Layers from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError(
            "Error(s) in loading weights for {}:\n\t{}".format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "error_msgs": error_msgs
        }
        return model, loading_info

    return model
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
                    *inputs, **kwargs):
    """
    Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
    Download and cache the pre-trained model file if needed.

    Params:
        pretrained_model_name: either:
            - a str with the name of a pre-trained model to load selected in the list of:
                . `openai-gpt`
            - a path or url to a pretrained model archive containing:
                . `openai_gpt_config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
        num_special_tokens: number of special tokens; when it differs from ``config.n_special`` the
            model's ``set_num_special_tokens`` is called with it after the weights are loaded
        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of
            pre-trained models. Its ``gamma``/``beta`` keys are renamed in place.
        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
        *inputs, **kwargs: additional input passed to the model class constructor after the config

    Returns:
        The loaded model, or ``None`` when the archive cannot be found (the error is logged).

    Raises:
        RuntimeError: when ``_load_from_state_dict`` reported error messages.
    """
    if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
        archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
    else:
        archive_file = pretrained_model_name
    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
    except FileNotFoundError:
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name,
                ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                archive_file))
        return None
    if resolved_archive_file == archive_file:
        logger.info("loading archive file {}".format(archive_file))
    else:
        logger.info("loading archive file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    tempdir = None
    # Everything from here on runs under try/finally so the temporary extraction
    # directory is removed even when loading fails part-way through.
    try:
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                # NOTE(review): extractall trusts member paths inside the archive;
                # only use archives from a trusted source.
                archive.extractall(tempdir)
            serialization_dir = tempdir

        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = OpenAIGPTConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            # NOTE(review): torch.load unpickles the file; only load weights from a trusted source.
            state_dict = torch.load(weights_path)

        # Rename old-style parameter names (gamma/beta) to weight/bias.
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Recursively load `module` and its children, extending the key prefix at each level.
            local_metadata = {} if metadata is None else metadata.get(
                prefix[:-1], {})
            module._load_from_state_dict(state_dict, prefix, local_metadata,
                                         True, missing_keys, unexpected_keys,
                                         error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Load into the `transformer` sub-module when the model has one, else into the model itself.
        load(model.transformer if hasattr(model, 'transformer') else model,
             prefix='')
        if len(missing_keys) > 0:
            logger.info(
                "Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info(
                "Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError(
                'Error(s) in loading state_dict for {}:\n\t{}'.format(
                    model.__class__.__name__, "\n\t".join(error_msgs)))
        # Add additional embeddings for special tokens if needed
        if num_special_tokens != config.n_special:
            model.set_num_special_tokens(num_special_tokens)
        return model
    finally:
        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
    """ Locate the tokenizer files for ``pretrained_model_name_or_path``, resolve them through
    ``cached_path`` and instantiate ``cls`` with them.

    Parameters:
        pretrained_model_name_or_path: either a key of ``cls.max_model_input_sizes`` (a shortcut name,
            whose files are taken from ``cls.pretrained_vocab_files_map``), or a path to a directory
            holding the files named in ``cls.vocab_files_names``, or a path to a single vocabulary file.
        inputs: positional arguments passed to ``cls``.
        kwargs: keyword arguments passed to ``cls``. ``cache_dir``, ``force_download`` and ``proxies``
            are removed first and forwarded to ``cached_path``. Resolved vocabulary files and the
            entries of the special tokens map are added only for keys not already present.

    Returns:
        The tokenizer, or ``None`` when no tokenizer file at all is found at a local path
        (the error is logged).

    Raises:
        EnvironmentError: re-raised (after logging) when ``cached_path`` cannot resolve a file.
    """
    cache_dir = kwargs.pop('cache_dir', None)
    force_download = kwargs.pop('force_download', False)
    proxies = kwargs.pop('proxies', None)

    s3_models = list(cls.max_model_input_sizes.keys())
    vocab_files = {}
    if pretrained_model_name_or_path in s3_models:
        # Get the vocabulary from AWS S3 bucket
        for file_id, map_list in cls.pretrained_vocab_files_map.items():
            vocab_files[file_id] = map_list[pretrained_model_name_or_path]
    else:
        # Get the vocabulary from local files
        logger.info(
            "Model name '{}' not found in model shortcut name list ({}). "
            "Assuming '{}' is a path or url to a directory containing tokenizer files."
            .format(pretrained_model_name_or_path, ', '.join(s3_models),
                    pretrained_model_name_or_path))

        # Look for the tokenizer main vocabulary files
        for file_id, file_name in cls.vocab_files_names.items():
            if os.path.isdir(pretrained_model_name_or_path):
                # If a directory is provided we look for the standard filenames
                full_file_name = os.path.join(
                    pretrained_model_name_or_path, file_name)
            else:
                # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
                full_file_name = pretrained_model_name_or_path
            if not os.path.exists(full_file_name):
                logger.info(
                    "Didn't find file {}. We won't load it.".format(
                        full_file_name))
                full_file_name = None
            vocab_files[file_id] = full_file_name

        # Look for the additional tokens files
        all_vocab_files_names = {
            'added_tokens_file': ADDED_TOKENS_FILE,
            'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE
        }

        # If a path to a file was provided, get the parent directory
        saved_directory = pretrained_model_name_or_path
        if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
            saved_directory = os.path.dirname(saved_directory)

        for file_id, file_name in all_vocab_files_names.items():
            full_file_name = os.path.join(saved_directory, file_name)
            if not os.path.exists(full_file_name):
                logger.info(
                    "Didn't find file {}. We won't load it.".format(
                        full_file_name))
                full_file_name = None
            vocab_files[file_id] = full_file_name

        if all(full_file_name is None for full_file_name in vocab_files.values()):
            # The trailing space after "files" keeps the two literals from running together.
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find tokenizer files "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(s3_models),
                    pretrained_model_name_or_path,
                ))
            return None

    # Get files from url, cache, or disk depending on the case
    try:
        resolved_vocab_files = {}
        for file_id, file_path in vocab_files.items():
            if file_path is None:
                resolved_vocab_files[file_id] = None
            else:
                resolved_vocab_files[file_id] = cached_path(
                    file_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies)
    except EnvironmentError:
        if pretrained_model_name_or_path in s3_models:
            logger.error("Couldn't reach server to download vocabulary.")
        else:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path, ', '.join(s3_models),
                    pretrained_model_name_or_path,
                    str(vocab_files.keys())))
        raise

    for file_id, file_path in vocab_files.items():
        if file_path == resolved_vocab_files[file_id]:
            logger.info("loading file {}".format(file_path))
        else:
            logger.info("loading file {} from cache at {}".format(
                file_path, resolved_vocab_files[file_id]))

    # Set max length if needed
    if pretrained_model_name_or_path in cls.max_model_input_sizes:
        # if we're using a pretrained model, ensure the tokenizer
        # wont index sequences longer than the number of positional embeddings
        max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
        if max_len is not None and isinstance(max_len, (int, float)):
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

    # Merge resolved_vocab_files arguments in kwargs.
    added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
    special_tokens_map_file = resolved_vocab_files.pop(
        'special_tokens_map_file', None)
    for args_name, file_path in resolved_vocab_files.items():
        if args_name not in kwargs:
            kwargs[args_name] = file_path
    if special_tokens_map_file is not None:
        # Context manager so the file handle is closed deterministically.
        with open(special_tokens_map_file, encoding="utf-8") as special_tokens_handle:
            special_tokens_map = json.load(special_tokens_handle)
        for key, value in special_tokens_map.items():
            if key not in kwargs:
                kwargs[key] = value

    # Instantiate tokenizer.
    tokenizer = cls(*inputs, **kwargs)

    # Add supplementary tokens.
    if added_tokens_file is not None:
        with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
            added_tok_encoder = json.load(added_tokens_handle)
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        tokenizer.added_tokens_encoder.update(added_tok_encoder)
        tokenizer.added_tokens_decoder.update(added_tok_decoder)

    return tokenizer
def get_config_dict(
    cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
) -> Tuple[Dict, Dict]:
    """
    Resolve `pretrained_model_name_or_path` to a configuration file and read it into a dictionary
    of parameters, to be used for instantiating a Config using `from_dict`.

    Parameters:
        pretrained_model_name_or_path (:obj:`string`):
            The identifier of the pre-trained checkpoint from which we want the dictionary of parameters:
            a key of the archive map, a directory (its ``CONFIG_NAME`` file is used), an existing file
            or remote url, or otherwise a name passed to ``hf_bucket_url``.
        pretrained_config_archive_map: (:obj:`Dict[str, str]`, `optional`) Dict:
            A map of `shortcut names` to `url`. By default, will use the current class attribute.
        kwargs:
            ``cache_dir``, ``force_download``, ``resume_download``, ``proxies`` and ``local_files_only``
            are removed and forwarded to ``cached_path``.

    Returns:
        :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration
        object, followed by the keyword arguments that were not consumed here.

    Raises:
        EnvironmentError: when the file cannot be resolved, or is not valid JSON.
    """
    # Options understood by cached_path; whatever remains in kwargs is handed back to the caller.
    download_options = dict(
        cache_dir=kwargs.pop("cache_dir", None),
        force_download=kwargs.pop("force_download", False),
        proxies=kwargs.pop("proxies", None),
        resume_download=kwargs.pop("resume_download", False),
        local_files_only=kwargs.pop("local_files_only", False),
    )

    archive_map = (
        cls.pretrained_config_archive_map if pretrained_config_archive_map is None else pretrained_config_archive_map
    )
    name = pretrained_model_name_or_path

    if name in archive_map:
        config_file = archive_map[name]
    elif os.path.isdir(name):
        config_file = os.path.join(name, CONFIG_NAME)
    elif os.path.isfile(name) or is_remote_url(name):
        config_file = name
    else:
        config_file = hf_bucket_url(name, postfix=CONFIG_NAME)

    try:
        # Load from URL or cache if already cached
        resolved_config_file = cached_path(config_file, **download_options)
        # A missing result is reported through the same handler as a failed download.
        if resolved_config_file is None:
            raise EnvironmentError
        # Load config dict
        config_dict = cls._dict_from_json_file(resolved_config_file)

    except EnvironmentError:
        if name not in archive_map:
            raise EnvironmentError(
                "Can't load '{}'. Make sure that:\n\n"
                "- '{}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                "- or '{}' is the correct path to a directory containing a '{}' file\n\n".format(
                    name, name, name, CONFIG_NAME
                )
            )
        raise EnvironmentError(
            "Couldn't reach server at '{}' to download pretrained model configuration file.".format(config_file)
        )

    except json.JSONDecodeError:
        raise EnvironmentError(
            "Couldn't reach server at '{}' to download configuration file or "
            "configuration file is not a valid JSON file. "
            "Please check network or file content here: {}.".format(config_file, resolved_config_file)
        )

    if resolved_config_file != config_file:
        logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
    else:
        logger.info("loading configuration file {}".format(config_file))

    return config_dict, kwargs
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
    """
    Instantiate a tokenizer (``cls``) from a pre-trained vocabulary file and repurpose its first
    two ``[unused...]`` vocabulary entries as the ``[Q]`` and ``[A]`` tokens.
    Download and cache the vocabulary file if needed.

    Params:
        pretrained_model_name: a key of ``PRETRAINED_VOCAB_ARCHIVE_MAP``, or a path/url to a vocabulary
            file, or a directory containing a file named ``VOCAB_NAME``
        cache_dir: an optional path forwarded to ``cached_path``
        *inputs, **kwargs: additional arguments for the tokenizer constructor; ``never_split`` is
            overwritten and ``max_len`` is capped for known pre-trained models

    Returns the tokenizer, or ``None`` (after logging an error) when the vocabulary file is not found.
    """
    vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP.get(pretrained_model_name, pretrained_model_name) \
        if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP else pretrained_model_name
    if os.path.isdir(vocab_file):
        vocab_file = os.path.join(vocab_file, VOCAB_NAME)

    # redirect to the cache, if necessary
    try:
        resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
    except FileNotFoundError:
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name,
                ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                vocab_file))
        return None

    if resolved_vocab_file != vocab_file:
        logger.info("loading vocabulary file {} from cache at {}".format(
            vocab_file, resolved_vocab_file))
    else:
        logger.info("loading vocabulary file {}".format(vocab_file))

    if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
        # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
        # than the number of positional embeddings
        positional_limit = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
        requested_limit = kwargs.get('max_len', int(1e12))
        kwargs['max_len'] = min(requested_limit, positional_limit)

    # NOTE(review): the original layout does not show whether this assignment sat inside the
    # max_len branch above; it is treated as unconditional here - confirm.
    kwargs['never_split'] = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]",
                             "[Q]", "[A]")

    # Instantiate tokenizer.
    tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

    def add_tokens(new_tokens):
        # Hand each new token the slot of the next "[unused..." entry, in id order.
        pending = new_tokens
        renamed = {}
        for idx in range(len(tokenizer.ids_to_tokens)):
            current = tokenizer.ids_to_tokens[idx]
            if "[unused" not in current:
                continue
            replacement, pending = pending[0], pending[1:]
            renamed[current] = replacement
            tokenizer.ids_to_tokens[idx] = replacement
            if not pending:
                break

        # Rebuild the token -> id mapping with the renamed entries, keeping the original order.
        rebuilt = collections.OrderedDict()
        for token in tokenizer.vocab.keys():
            rebuilt[renamed.get(token, token)] = tokenizer.vocab[token]
        tokenizer.vocab = tokenizer.wordpiece_tokenizer.vocab = rebuilt

    add_tokens(["[Q]", "[A]"])
    return tokenizer