def load_pretrained_xlnet(pretrained_model_name, cache_dir=None):
    """Download (if needed) the named pre-trained XLNet model and return
    the directory in which it is cached.

    Args:
        pretrained_model_name (str): Name of a model listed in
            ``_MODEL2URL``.
        cache_dir (str, optional): Cache directory. Defaults to the
            standard download directory for "xlnet".

    Returns:
        Path of the cached model directory.

    Raises:
        ValueError: If ``pretrained_model_name`` is unknown.
    """
    if pretrained_model_name not in _MODEL2URL:
        raise ValueError(
            "Pre-trained model not found: {}".format(pretrained_model_name))
    download_path = _MODEL2URL[pretrained_model_name]

    if cache_dir is None:
        cache_dir = default_download_dir("xlnet")

    # The "xlnet_" prefix is required because of the way the xlnet model
    # archive is bundled: the extracted directory carries that prefix.
    archive_name = "xlnet_" + download_path.split('/')[-1]
    cache_path = os.path.join(cache_dir, archive_name.split('.')[0])

    if os.path.exists(cache_path):
        print("Using cached pre-trained model {} from: {}".format(
            pretrained_model_name, cache_dir))
    else:
        maybe_download(download_path, cache_dir, extract=True)
    return cache_path
def download_checkpoint(cls, pretrained_model_name, cache_dir=None):
    r"""Download the specified pre-trained checkpoint, and return the
    directory in which the checkpoint is cached.

    Args:
        pretrained_model_name (str): Name of the model checkpoint.
        cache_dir (str, optional): Path to the cache directory. If
            `None`, uses the default directory (user's home directory).

    Returns:
        Path to the cache directory.
    """
    if pretrained_model_name in cls._MODEL2URL:
        download_path = cls._MODEL2URL[pretrained_model_name]
    else:
        raise ValueError(
            "Pre-trained model not found: {}".format(pretrained_model_name))

    if cache_dir is None:
        cache_path = default_download_dir(cls._MODEL_NAME)
    else:
        cache_path = Path(cache_dir)
    # Each named checkpoint gets its own sub-directory under the cache root.
    cache_path = cache_path / pretrained_model_name

    if not cache_path.exists():
        if isinstance(download_path, list):
            # A list URL means multiple loose files: download each one
            # straight into the cache directory; no extraction involved.
            for path in download_path:
                maybe_download(path, str(cache_path))
        else:
            # A single URL is an archive: download and extract it, then
            # flatten the extracted top-level folder so the checkpoint
            # files sit directly inside ``cache_path``.
            filename = download_path.split('/')[-1]
            maybe_download(download_path, str(cache_path), extract=True)
            folder = None
            for file in cache_path.iterdir():
                if file.is_dir():
                    # NOTE(review): keeps the *last* directory seen —
                    # assumes the archive extracts into exactly one
                    # sub-folder; TODO confirm for every listed URL.
                    folder = file
            assert folder is not None
            # Remove the downloaded archive itself.
            (cache_path / filename).unlink()
            # Move every extracted file one level up, then drop the
            # now-empty extraction folder.
            for file in folder.iterdir():
                file.rename(file.parents[1] / file.name)
            folder.rmdir()
        print("Pre-trained {} checkpoint {} cached to {}".format(
            cls._MODEL_NAME, pretrained_model_name, cache_path))
    else:
        print("Using cached pre-trained {} checkpoint from {}.".format(
            cls._MODEL_NAME, cache_path))
    return str(cache_path)
def setUp(self):
    """Download the sentencepiece fixture model and build a tokenizer."""
    self.tmp_dir = tempfile.TemporaryDirectory()
    # Small sentencepiece model hosted in the pytorch-transformers fork,
    # fetched raw from GitHub.
    vocab_url = (
        'https://github.com/gpengzhi/pytorch-transformers/blob/master/'
        'pytorch_transformers/tests/fixtures/test_sentencepiece.model'
        '?raw=true')
    self.SAMPLE_VOCAB = maybe_download(vocab_url, self.tmp_dir.name)
    self.tokenizer = XLNetTokenizer.load(
        self.SAMPLE_VOCAB[0], configs={'keep_accents': True})
    self.tokenizer.save(self.tmp_dir.name)
def setUp(self):
    """Download the sentencepiece fixture model and build a tokenizer."""
    self.tmp_dir = tempfile.TemporaryDirectory()
    # Use the test sentencepiece model downloaded from huggingface
    # transformers (raw file from GitHub).
    fixture_url = (
        'https://github.com/huggingface/transformers/blob/master/'
        'transformers/tests/fixtures/test_sentencepiece.model?raw=true')
    self.SAMPLE_VOCAB = maybe_download(fixture_url, self.tmp_dir.name)
    self.tokenizer = XLNetTokenizer.load(
        self.SAMPLE_VOCAB[0], configs={'keep_accents': True})
    self.tokenizer.save(self.tmp_dir.name)