コード例 #1
0
 def test_no_connection(self):
     """Exercise ``cached_download`` under every simulated offline mode.

     The valid URL is downloaded once before going offline. Then, for each
     member of ``OfflineSimulationMode``, requesting the invalid-revision
     URL or forcing a re-download of the valid URL must raise
     ``ValueError`` matching "Connection error", while a plain request for
     the valid URL must still return a non-``None`` result.
     """
     bad_revision_url = hf_hub_url(
         MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID
     )
     good_revision_url = hf_hub_url(
         MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
     )

     # Fetch the valid file once while the connection is still available.
     initial_path = cached_download(good_revision_url, force_download=True)
     self.assertIsNotNone(initial_path)

     for simulated_mode in OfflineSimulationMode:
         with offline(mode=simulated_mode):
             # A URL that was never downloaded cannot be resolved offline.
             with self.assertRaisesRegex(ValueError, "Connection error"):
                 cached_download(bad_revision_url)

             # A forced re-download needs the connection as well.
             with self.assertRaisesRegex(ValueError, "Connection error"):
                 cached_download(good_revision_url, force_download=True)

             # Without force_download the earlier download is still usable.
             offline_path = cached_download(good_revision_url)
             self.assertIsNotNone(offline_path)
コード例 #2
0
 def test_lfs_object(self):
     """Download the weights file and check the metadata recorded for it.

     ``filename_to_url`` applied to the downloaded path must return the
     request URL paired with ``PINNED_SHA256`` wrapped in double quotes.
     """
     weights_url = hf_hub_url(
         MODEL_ID,
         filename=PYTORCH_WEIGHTS_NAME,
         revision=REVISION_ID_DEFAULT,
     )
     local_path = cached_download(weights_url, force_download=True)
     recorded = filename_to_url(local_path)
     expected = (weights_url, f'"{PINNED_SHA256}"')
     self.assertEqual(recorded, expected)
コード例 #3
0
    def test_upload_buffer(self):
        """Upload raw bytes with ``upload_file`` and verify the downloaded copy.

        A repo is created, the encoded ``self.tmp_file_content`` is uploaded
        to ``temp/new_file.md``, and the text read back through
        ``cached_download`` must equal the original content. The repo is
        deleted afterwards whether or not the checks passed.
        """
        REPO_NAME = repo_name("buffer")
        self._api.create_repo(repo_id=REPO_NAME, token=self._token)
        try:
            buffer = BytesIO()
            buffer.write(self.tmp_file_content.encode())
            self._api.upload_file(
                # getvalue() hands over the accumulated bytes, not the buffer.
                path_or_fileobj=buffer.getvalue(),
                path_in_repo="temp/new_file.md",
                repo_id=f"{USER}/{REPO_NAME}",
                token=self._token,
            )
            url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(
                ENDPOINT_STAGING,
                user=USER,
                repo=REPO_NAME,
            )
            filepath = cached_download(url, force_download=True)
            # Decode explicitly as UTF-8: the payload came from str.encode(),
            # whose default is UTF-8, whereas open() without an encoding uses
            # the locale's preferred encoding and could decode differently.
            with open(filepath, encoding="utf-8") as downloaded_file:
                content = downloaded_file.read()
            self.assertEqual(content, self.tmp_file_content)

        except Exception as err:
            # Any exception raised above is reported as a test failure.
            self.fail(err)
        finally:
            self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
コード例 #4
0
    def test_upload_file_bytesio(self):
        """Upload a ``BytesIO`` object with ``upload_file`` and verify the result.

        A repo is created, the in-memory stream is uploaded to
        ``temp/new_file.md``, and the text read back through
        ``cached_download`` must equal the stream's decoded bytes. The repo
        is deleted afterwards whether or not the checks passed.
        """
        REPO_NAME = repo_name("bytesio")
        self._api.create_repo(name=REPO_NAME, token=self._token)
        try:
            filecontent = BytesIO(b"File content, but in bytes IO")
            self._api.upload_file(
                path_or_fileobj=filecontent,
                path_in_repo="temp/new_file.md",
                repo_id=f"{USER}/{REPO_NAME}",
                token=self._token,
            )
            url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(
                ENDPOINT_STAGING,
                user=USER,
                repo=REPO_NAME,
            )
            filepath = cached_download(url, force_download=True)
            # Read as UTF-8 so the text matches bytes.decode() below, whose
            # default is UTF-8; open() without an encoding would use the
            # locale's preferred encoding instead.
            with open(filepath, encoding="utf-8") as downloaded_file:
                content = downloaded_file.read()
            self.assertEqual(content, filecontent.getvalue().decode())

        except Exception as err:
            # Any exception raised above is reported as a test failure.
            self.fail(err)
        finally:
            self._api.delete_repo(name=REPO_NAME, token=self._token)
コード例 #5
0
    def test_upload_file_fileobj(self):
        """Upload an open binary file handle and verify the downloaded copy.

        A repo is created, ``self.tmp_file`` is opened in binary mode and
        passed to ``upload_file`` as ``temp/new_file.md``, and the text read
        back through ``cached_download`` must equal ``self.tmp_file_content``.
        The repo is deleted afterwards whether or not the checks passed.
        """
        target_repo = repo_name("fileobj")
        self._api.create_repo(name=target_repo, token=self._token)
        try:
            with open(self.tmp_file, "rb") as source_stream:
                self._api.upload_file(
                    path_or_fileobj=source_stream,
                    path_in_repo="temp/new_file.md",
                    repo_id=f"{USER}/{target_repo}",
                    token=self._token,
                )

            # Address of the uploaded file on the staging endpoint.
            resolve_url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(
                ENDPOINT_STAGING,
                user=USER,
                repo=target_repo,
            )
            local_copy = cached_download(resolve_url, force_download=True)
            with open(local_copy) as fetched:
                fetched_text = fetched.read()
            self.assertEqual(fetched_text, self.tmp_file_content)

        except Exception as error:
            # Any exception raised above is reported as a test failure.
            self.fail(error)
        finally:
            self._api.delete_repo(name=target_repo, token=self._token)
コード例 #6
0
 def test_standard_object_rev(self):
     """Download the config file at one specific commit and check its metadata.

     The second element returned by ``filename_to_url`` must differ from
     the double-quoted ``PINNED_SHA1``.
     """
     # Same file as the default-revision test, requested at another revision.
     commit_url = hf_hub_url(
         MODEL_ID,
         filename=CONFIG_NAME,
         revision=REVISION_ID_ONE_SPECIFIC_COMMIT,
     )
     local_path = cached_download(commit_url, force_download=True)
     recorded = filename_to_url(local_path)
     second_field = recorded[1]
     self.assertNotEqual(second_field, f'"{PINNED_SHA1}"')
コード例 #7
0
 def test_dataset_lfs_object(self):
     """Download ``dev-v1.1.json`` from the dataset repo and check its metadata.

     ``filename_to_url`` applied to the downloaded path must return the
     request URL paired with the expected double-quoted hash string.
     """
     dataset_url = hf_hub_url(
         DATASET_ID,
         filename="dev-v1.1.json",
         repo_type=REPO_TYPE_DATASET,
         revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
     )
     local_path = cached_download(dataset_url, force_download=True)
     recorded = filename_to_url(local_path)
     expected = (
         dataset_url,
         '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"',
     )
     self.assertEqual(recorded, expected)
コード例 #8
0
 def test_dataset_standard_object_rev(self):
     """Check both ways of addressing a dataset file, then download it.

     The URL built with ``repo_type=REPO_TYPE_DATASET`` must equal the URL
     built by prefixing the repo id with ``datasets/``. After downloading,
     the second element returned by ``filename_to_url`` must differ from
     the double-quoted ``PINNED_SHA1``.
     """
     typed_url = hf_hub_url(
         DATASET_ID,
         filename=DATASET_SAMPLE_PY_FILE,
         repo_type=REPO_TYPE_DATASET,
         revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
     )
     # Prefixing the repo id with "datasets" is expected to give the same URL.
     prefixed_url = hf_hub_url(
         repo_id=f"datasets/{DATASET_ID}",
         filename=DATASET_SAMPLE_PY_FILE,
         revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
     )
     self.assertEqual(typed_url, prefixed_url)

     # Download through the first URL and inspect the stored metadata.
     local_path = cached_download(typed_url, force_download=True)
     recorded = filename_to_url(local_path)
     second_field = recorded[1]
     self.assertNotEqual(second_field, f'"{PINNED_SHA1}"')
コード例 #9
0
ファイル: HFModelHub.py プロジェクト: ajayarunachalam/glasses
    def from_pretrained(
        pretrained_model_name_or_path: Optional[str],
        strict: bool = True,
        map_location: Optional[str] = "cpu",
        force_download: bool = False,
        resume_download: bool = False,
        proxies: Optional[Dict] = None,
        use_auth_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
    ) -> StateDict:
        r"""
        Load the state dict of a pretrained PyTorch model from a local directory or from huggingface-hub.

        The weights file (``PYTORCH_WEIGHTS_NAME``) is read from the directory when
        ``pretrained_model_name_or_path`` is an existing directory, otherwise it is downloaded with
        ``cached_download``. The configuration file (``CONFIG_NAME``) is located the same way and parsed,
        but its content is not used yet.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:
                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - You can add `revision` by appending `@` at the end of model_id simply like this: ``dbmdz/bert-base-german-cased@main``
                      Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id,
                      since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
                    - A path to a `directory` containing the model weights, e.g., ``./my_model_directory/``.
                      An existing directory always takes precedence over a model id with the same spelling.
            strict (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Currently unused by this function.
            map_location (:obj:`str`, `optional`, defaults to :obj:`"cpu"`):
                Device string converted with ``torch.device`` and passed to ``torch.load``. :obj:`None` is passed
                to ``torch.load`` unchanged.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (:obj:`str` or `bool`, `optional`):
                The token to use as HTTP bearer authorization for remote files; forwarded to ``cached_download``.

        Returns:
            The object deserialized by ``torch.load`` from the weights file.

        Raises:
            ValueError: If ``pretrained_model_name_or_path`` is :obj:`None`.

        .. note::
            Passing :obj:`use_auth_token=True` is required when you want to use a private model.
        """
        if pretrained_model_name_or_path is None:
            raise ValueError(
                "pretrained_model_name_or_path must be a model id or a path to a "
                "local directory, got None"
            )

        model_id = pretrained_model_name_or_path
        # Convert early so an invalid device string fails before any download.
        # None is kept as-is and forwarded to torch.load unchanged.
        if map_location is not None:
            map_location = torch.device(map_location)

        revision = None
        # Check for a directory before looking for "@revision", so a local
        # path that happens to contain "@" is not split apart.
        is_local = os.path.isdir(model_id)
        if not is_local and len(model_id.split("@")) == 2:
            model_id, revision = model_id.split("@")
            is_local = os.path.isdir(model_id)

        if is_local:
            config_file = os.path.join(model_id, CONFIG_NAME)
            if not os.path.isfile(config_file):
                # Do not query the Hub here: the directory name is not
                # necessarily a model id, and the weights are read locally.
                logger.warning("%s NOT FOUND in local directory", CONFIG_NAME)
                config_file = None

            logger.info("LOADING weights from local directory")
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        else:
            try:
                config_url = hf_hub_url(
                    model_id, filename=CONFIG_NAME, revision=revision
                )
                config_file = cached_download(
                    config_url,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                )
            except requests.exceptions.RequestException:
                # A missing config is tolerated; only the weights are required.
                logger.warning("config.json NOT FOUND in HuggingFace Hub")
                config_file = None

            model_url = hf_hub_url(
                model_id, filename=PYTORCH_WEIGHTS_NAME, revision=revision
            )
            model_file = cached_download(
                model_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )

            logger.debug(model_file)

        if config_file is not None:
            with open(config_file, "r", encoding="utf-8") as f:
                # The config is not used yet; parsing it only checks that the
                # file contains valid JSON.
                json.load(f)

        # NOTE(security): torch.load relies on pickle, so loading a checkpoint
        # from an untrusted repo or directory can execute arbitrary code.
        state_dict = torch.load(model_file, map_location=map_location)

        return state_dict
コード例 #10
0
 def test_standard_object(self):
     """Download the config file at the default revision and check its metadata.

     ``filename_to_url`` applied to the downloaded path must return the
     request URL paired with ``PINNED_SHA1`` wrapped in double quotes.
     """
     config_url = hf_hub_url(
         MODEL_ID,
         filename=CONFIG_NAME,
         revision=REVISION_ID_DEFAULT,
     )
     local_path = cached_download(config_url, force_download=True)
     recorded = filename_to_url(local_path)
     expected = (config_url, f'"{PINNED_SHA1}"')
     self.assertEqual(recorded, expected)
コード例 #11
0
 def test_revision_not_found(self):
     """Requesting the config file at an invalid revision must raise an HTTP 404.

     ``cached_download`` is expected to raise
     ``requests.exceptions.HTTPError`` matching "404 Client Error".
     """
     # The file name is valid; only the revision is not.
     missing_revision_url = hf_hub_url(
         MODEL_ID,
         filename=CONFIG_NAME,
         revision=REVISION_ID_INVALID,
     )
     expected_error = requests.exceptions.HTTPError
     with self.assertRaisesRegex(expected_error, "404 Client Error"):
         cached_download(missing_revision_url)
コード例 #12
0
 def test_file_not_found(self):
     """Requesting ``missing.bin`` without a revision must raise an HTTP 404.

     ``cached_download`` is expected to raise
     ``requests.exceptions.HTTPError`` matching "404 Client Error".
     """
     # No revision is passed; only the file name is wrong.
     missing_file_url = hf_hub_url(MODEL_ID, filename="missing.bin")
     expected_error = requests.exceptions.HTTPError
     with self.assertRaisesRegex(expected_error, "404 Client Error"):
         cached_download(missing_file_url)
コード例 #13
0
 def test_bogus_url(self):
     """Downloading from ``https://bogus`` must raise a connection error.

     ``cached_download`` is expected to raise ``ValueError`` matching
     "Connection error".
     """
     unreachable_url = "https://bogus"
     with self.assertRaisesRegex(ValueError, "Connection error"):
         cached_download(unreachable_url)