def test_no_connection(self):
    """Offline mode raises for any network fetch, but cached files still resolve."""
    bad_url = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID
    )
    good_url = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
    )
    # Prime the cache while online so the final lookup can succeed offline.
    self.assertIsNotNone(cached_download(good_url, force_download=True))
    for mode in OfflineSimulationMode:
        with offline(mode=mode):
            with self.assertRaisesRegex(ValueError, "Connection error"):
                cached_download(bad_url)
            with self.assertRaisesRegex(ValueError, "Connection error"):
                cached_download(good_url, force_download=True)
            # The previously cached copy is served without a network call.
            self.assertIsNotNone(cached_download(good_url))
def test_lfs_object(self):
    """Downloading an LFS-tracked file records the pinned sha256 ETag."""
    weights_url = hf_hub_url(
        MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT
    )
    local_path = cached_download(weights_url, force_download=True)
    self.assertEqual(
        filename_to_url(local_path), (weights_url, f'"{PINNED_SHA256}"')
    )
def test_upload_buffer(self):
    """Raw bytes (from a BytesIO buffer) can be uploaded and fetched back intact."""
    REPO_NAME = repo_name("buffer")
    self._api.create_repo(repo_id=REPO_NAME, token=self._token)
    try:
        payload = BytesIO()
        payload.write(self.tmp_file_content.encode())
        self._api.upload_file(
            path_or_fileobj=payload.getvalue(),
            path_in_repo="temp/new_file.md",
            repo_id=f"{USER}/{REPO_NAME}",
            token=self._token,
        )
        url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(
            ENDPOINT_STAGING,
            user=USER,
            repo=REPO_NAME,
        )
        local_path = cached_download(url, force_download=True)
        with open(local_path) as fh:
            round_tripped = fh.read()
        self.assertEqual(round_tripped, self.tmp_file_content)
    except Exception as err:
        self.fail(err)
    finally:
        # Always remove the staging repo, even on failure.
        self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
def test_upload_file_bytesio(self):
    """A BytesIO file object can be uploaded and downloaded back intact."""
    REPO_NAME = repo_name("bytesio")
    # Use repo_id= (not the deprecated name=) for consistency with
    # test_upload_buffer, which already uses repo_id= for both calls.
    self._api.create_repo(repo_id=REPO_NAME, token=self._token)
    try:
        filecontent = BytesIO(b"File content, but in bytes IO")
        self._api.upload_file(
            path_or_fileobj=filecontent,
            path_in_repo="temp/new_file.md",
            repo_id=f"{USER}/{REPO_NAME}",
            token=self._token,
        )
        url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(
            ENDPOINT_STAGING,
            user=USER,
            repo=REPO_NAME,
        )
        filepath = cached_download(url, force_download=True)
        with open(filepath) as downloaded_file:
            content = downloaded_file.read()
        self.assertEqual(content, filecontent.getvalue().decode())
    except Exception as err:
        self.fail(err)
    finally:
        # Always remove the staging repo, even on failure.
        self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
def test_upload_file_fileobj(self):
    """A file opened in binary mode can be uploaded via its stream object."""
    REPO_NAME = repo_name("fileobj")
    # Use repo_id= (not the deprecated name=) for consistency with
    # test_upload_buffer, which already uses repo_id= for both calls.
    self._api.create_repo(repo_id=REPO_NAME, token=self._token)
    try:
        with open(self.tmp_file, "rb") as filestream:
            self._api.upload_file(
                path_or_fileobj=filestream,
                path_in_repo="temp/new_file.md",
                repo_id=f"{USER}/{REPO_NAME}",
                token=self._token,
            )
        url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(
            ENDPOINT_STAGING,
            user=USER,
            repo=REPO_NAME,
        )
        filepath = cached_download(url, force_download=True)
        with open(filepath) as downloaded_file:
            content = downloaded_file.read()
        self.assertEqual(content, self.tmp_file_content)
    except Exception as err:
        self.fail(err)
    finally:
        # Always remove the staging repo, even on failure.
        self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
def test_standard_object_rev(self):
    """The same file at a different revision yields a different ETag."""
    config_url = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT
    )
    local_path = cached_download(config_url, force_download=True)
    _, etag = filename_to_url(local_path)
    self.assertNotEqual(etag, f'"{PINNED_SHA1}"')
def test_dataset_lfs_object(self):
    """An LFS-tracked dataset file resolves to its expected sha256 ETag."""
    url = hf_hub_url(
        DATASET_ID,
        filename="dev-v1.1.json",
        repo_type=REPO_TYPE_DATASET,
        revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
    )
    local_path = cached_download(url, force_download=True)
    expected_etag = (
        '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'
    )
    self.assertEqual(filename_to_url(local_path), (url, expected_etag))
def test_dataset_standard_object_rev(self):
    """repo_type= and a "datasets/" repo_id prefix build identical URLs."""
    shared_kwargs = dict(
        filename=DATASET_SAMPLE_PY_FILE,
        revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
    )
    url = hf_hub_url(DATASET_ID, repo_type=REPO_TYPE_DATASET, **shared_kwargs)
    url2 = hf_hub_url(repo_id=f"datasets/{DATASET_ID}", **shared_kwargs)
    self.assertEqual(url, url2)
    # Download it and confirm the ETag differs from the pinned model sha1.
    local_path = cached_download(url, force_download=True)
    _, etag = filename_to_url(local_path)
    self.assertNotEqual(etag, f'"{PINNED_SHA1}"')
def from_pretrained(
    pretrained_model_name_or_path: Optional[str],
    strict: bool = True,
    map_location: Optional[str] = "cpu",
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    local_files_only: bool = False,
) -> StateDict:
    r"""
    Load a pretrained pytorch state dict from a local directory or from
    huggingface-hub.

    Parameters:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
            Can be either:
                - A string, the `model id` of a pretrained model hosted inside a
                  model repo on huggingface.co. Valid model ids can be located
                  at the root-level, like ``bert-base-uncased``, or namespaced
                  under a user or organization name, like
                  ``dbmdz/bert-base-german-cased``.
                - You can add `revision` by appending `@` at the end of model_id
                  simply like this: ``dbmdz/bert-base-german-cased@main``.
                  Revision is the specific model version to use. It can be a
                  branch name, a tag name, or a commit id, since we use a
                  git-based system for storing models and other artifacts on
                  huggingface.co, so ``revision`` can be any identifier allowed
                  by git.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.,
                  ``./my_model_directory/``.
        strict (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Currently unused here; kept for interface compatibility with callers.
        map_location (:obj:`str`, `optional`, defaults to :obj:`"cpu"`):
            Device spec passed to :func:`torch.device` for loading the weights.
        cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
            Path to a directory in which a downloaded pretrained model
            configuration should be cached if the standard cache should not be
            used.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force the (re-)download of the model weights and
            configuration files, overriding the cached versions if they exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Will attempt
            to resume the download if such a file exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g.,
            :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`.
            The proxies are used on each request.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only look at local files (i.e., do not try to
            download the model).
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If
            :obj:`True`, will use the token generated when running
            :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).

    Returns:
        The loaded state dict.

    .. note::
        Passing :obj:`use_auth_token=True` is required when you want to use a
        private model.
    """
    model_id = pretrained_model_name_or_path
    map_location = torch.device(map_location)

    # Support the "repo_id@revision" shorthand documented above.
    revision = None
    if len(model_id.split("@")) == 2:
        model_id, revision = model_id.split("@")

    # Prefer a local directory if one exists at the given path. Using
    # os.path.isdir (instead of `model_id in os.listdir()`) supports nested
    # paths such as "./my_model_directory/" as the docstring promises, and
    # avoids NotADirectoryError when a *file* with the same name exists.
    if os.path.isdir(model_id) and CONFIG_NAME in os.listdir(model_id):
        config_file = os.path.join(model_id, CONFIG_NAME)
    else:
        try:
            config_url = hf_hub_url(
                model_id, filename=CONFIG_NAME, revision=revision
            )
            config_file = cached_download(
                config_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
        except requests.exceptions.RequestException:
            # The config is optional: loading proceeds with weights only.
            logger.warning("config.json NOT FOUND in HuggingFace Hub")
            config_file = None

    if os.path.isdir(model_id):
        # Use the module logger rather than a bare print().
        logger.info("LOADING weights from local directory")
        model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
    else:
        model_url = hf_hub_url(
            model_id, filename=PYTORCH_WEIGHTS_NAME, revision=revision
        )
        model_file = cached_download(
            model_url,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )
    logger.debug(model_file)

    # Parse the config (validating it is well-formed JSON) even though the
    # parsed value is not consumed here yet.
    if config_file is not None:
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)  # noqa: F841 -- kept for future use

    state_dict = torch.load(model_file, map_location=map_location)
    return state_dict
def test_standard_object(self):
    """A regular (non-LFS) file resolves to the pinned sha1 ETag."""
    config_url = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
    )
    local_path = cached_download(config_url, force_download=True)
    self.assertEqual(
        filename_to_url(local_path), (config_url, f'"{PINNED_SHA1}"')
    )
def test_revision_not_found(self):
    """A valid file under a nonexistent revision raises a 404."""
    bad_url = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID
    )
    with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
        cached_download(bad_url)
def test_file_not_found(self):
    """A nonexistent filename under the default revision raises a 404."""
    missing_url = hf_hub_url(MODEL_ID, filename="missing.bin")
    with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
        cached_download(missing_url)
def test_bogus_url(self):
    """An unreachable host surfaces as a connection error."""
    unreachable = "https://bogus"
    with self.assertRaisesRegex(ValueError, "Connection error"):
        cached_download(unreachable)