Code example #1
File: xlnet_utils.py  Project: wj1721/texar
def load_pretrained_xlnet(pretrained_model_name, cache_dir=None):
    """
    Return the directory in which the pretrained model is cached.
    """
    if pretrained_model_name in _MODEL2URL:
        download_path = _MODEL2URL[pretrained_model_name]
    else:
        raise ValueError(
            "Pre-trained model not found: {}".format(pretrained_model_name))

    if cache_dir is None:
        cache_dir = default_download_dir("xlnet")

    file_name = download_path.split('/')[-1]
    # The "xlnet_" prefix is required because of the way the XLNet model
    # archive is bundled.
    file_name = "xlnet_" + file_name

    cache_path = os.path.join(cache_dir, file_name.split('.')[0])
    if not os.path.exists(cache_path):
        maybe_download(download_path, cache_dir, extract=True)
    else:
        print("Using cached pre-trained model {} from: {}".format(
            pretrained_model_name, cache_dir))

    return cache_path
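For context, a minimal call sketch for the function above. The model name and the printed cache layout are assumptions for illustration; the snippet itself only shows that unknown names raise a ValueError.

# Minimal usage sketch (the model name "xlnet-base-cased" is a hypothetical key
# of _MODEL2URL; adjust to whatever keys the project actually registers).
cache_path = load_pretrained_xlnet("xlnet-base-cased")
# On a cache miss the archive is downloaded and extracted; either way the
# function returns the directory holding the extracted checkpoint files.
print(cache_path)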
Code example #2
    @classmethod
    def download_checkpoint(cls, pretrained_model_name, cache_dir=None):
        r"""Download the specified pre-trained checkpoint, and return the
        directory in which the checkpoint is cached.

        Args:
            pretrained_model_name (str): Name of the model checkpoint.
            cache_dir (str, optional): Path to the cache directory. If `None`,
                uses the default directory (user's home directory).

        Returns:
            Path to the cache directory.
        """
        if pretrained_model_name in cls._MODEL2URL:
            download_path = cls._MODEL2URL[pretrained_model_name]
        else:
            raise ValueError(
                "Pre-trained model not found: {}".format(pretrained_model_name))

        if cache_dir is None:
            cache_path = default_download_dir(cls._MODEL_NAME)
        else:
            cache_path = Path(cache_dir)
        cache_path = cache_path / pretrained_model_name

        if not cache_path.exists():
            if isinstance(download_path, list):
                # Multiple files: download each one into the cache directory.
                for path in download_path:
                    maybe_download(path, str(cache_path))
            else:
                # Single archive: download, extract, then flatten the extracted
                # folder into the cache directory.
                filename = download_path.split('/')[-1]
                maybe_download(download_path, str(cache_path), extract=True)
                folder = None
                for file in cache_path.iterdir():
                    if file.is_dir():
                        folder = file
                assert folder is not None
                # Remove the downloaded archive and move the extracted files up
                # one level, then delete the now-empty folder.
                (cache_path / filename).unlink()
                for file in folder.iterdir():
                    file.rename(file.parents[1] / file.name)
                folder.rmdir()
            print("Pre-trained {} checkpoint {} cached to {}".format(
                cls._MODEL_NAME, pretrained_model_name, cache_path))
        else:
            print("Using cached pre-trained {} checkpoint from {}.".format(
                cls._MODEL_NAME, cache_path))

        return str(cache_path)
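A hedged call sketch for the classmethod above. The snippet defines shared mixin logic, so the concrete class name used here (PretrainedXLNetMixin) and the model name are assumptions for illustration only.

# Usage sketch (class and model names are assumed; substitute the real subclass).
from pathlib import Path

checkpoint_dir = PretrainedXLNetMixin.download_checkpoint("xlnet-base-cased")
# download_checkpoint returns the cache directory as a string, with the
# extracted checkpoint files flattened directly inside it.
print(list(Path(checkpoint_dir).iterdir()))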
Code example #3
    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.SAMPLE_VOCAB = maybe_download(
            'https://github.com/gpengzhi/pytorch-transformers/blob/master/'
            'pytorch_transformers/tests/fixtures/test_sentencepiece.model'
            '?raw=true', self.tmp_dir.name)

        self.tokenizer = XLNetTokenizer.load(self.SAMPLE_VOCAB[0],
                                             configs={'keep_accents': True})
        self.tokenizer.save(self.tmp_dir.name)
Code example #4
    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory()
        # Use the test sentencepiece model downloaded from huggingface
        # transformers
        self.SAMPLE_VOCAB = maybe_download(
            'https://github.com/huggingface/transformers/blob/master/'
            'transformers/tests/fixtures/test_sentencepiece.model?raw=true',
            self.tmp_dir.name)

        self.tokenizer = XLNetTokenizer.load(self.SAMPLE_VOCAB[0],
                                             configs={'keep_accents': True})
        self.tokenizer.save(self.tmp_dir.name)
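Both test fixtures above create a TemporaryDirectory but do not show its disposal. A minimal companion sketch, assuming the standard unittest lifecycle; the reload test is hypothetical and assumes XLNetTokenizer.load accepts the directory written by save().

    def tearDown(self):
        # Remove the fixture files created in setUp.
        self.tmp_dir.cleanup()

    def test_reload_saved_tokenizer(self):
        # Hypothetical round-trip check: reload the tokenizer that setUp saved
        # into the temporary directory (assumes load() accepts that directory).
        tokenizer = XLNetTokenizer.load(self.tmp_dir.name)
        self.assertIsNotNone(tokenizer)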