def check_download_model_with_regex(self, regex, allow=True):
    """Download the `main` snapshot filtered by ``regex`` and verify its contents.

    Args:
        regex: Pattern forwarded to ``snapshot_download`` as ``allow_regex``
            when ``allow`` is true, otherwise as ``ignore_regex``.
        allow: Selects whether ``regex`` is used as an allow-list (default)
            or as an ignore-list.
    """
    # Exactly one of the two filters carries the pattern; the other stays None.
    allow_regex = regex if allow else None
    ignore_regex = None if allow else regex

    # Test `main` branch
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = snapshot_download(
            f"{USER}/{REPO_NAME}",
            revision="main",
            cache_dir=tmpdirname,
            allow_regex=allow_regex,
            ignore_regex=ignore_regex,
        )

        # The filter must keep only the two dummy files: .gitattributes is
        # expected to be filtered out of the snapshot.
        folder_contents = os.listdir(storage_folder)
        self.assertEqual(len(folder_contents), 2)
        self.assertIn("dummy_file.txt", folder_contents)
        self.assertIn("dummy_file_2.txt", folder_contents)
        self.assertNotIn(".gitattributes", folder_contents)

        with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
            contents = f.read()
        self.assertEqual(contents, "v2")

        # folder name contains the revision's commit sha.
        self.assertIn(self.second_commit_hash, storage_folder)
def test_download_model(self):
    """Check snapshots downloaded for the `main` branch and for a pinned commit."""
    repo_id = f"{USER}/{REPO_NAME}"

    # Test `main` branch
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = snapshot_download(
            repo_id, revision="main", cache_dir=tmpdirname
        )

        # folder contains the two dummy files and the .gitattributes
        folder_contents = os.listdir(storage_folder)
        self.assertEqual(len(folder_contents), 3)
        self.assertIn("dummy_file.txt", folder_contents)
        self.assertIn("dummy_file_2.txt", folder_contents)
        self.assertIn(".gitattributes", folder_contents)

        with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
            contents = f.read()
        self.assertEqual(contents, "v2")

        # folder name contains the revision's commit sha.
        self.assertIn(self.second_commit_hash, storage_folder)

    # Test with specific revision
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_folder = snapshot_download(
            repo_id,
            revision=self.first_commit_hash,
            cache_dir=tmpdirname,
        )

        # At the first commit only two entries are expected:
        # dummy_file.txt and the .gitattributes.
        folder_contents = os.listdir(storage_folder)
        self.assertEqual(len(folder_contents), 2)
        self.assertIn("dummy_file.txt", folder_contents)
        self.assertIn(".gitattributes", folder_contents)

        with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
            contents = f.read()
        self.assertEqual(contents, "v1")

        # folder name contains the revision's commit sha.
        self.assertIn(self.first_commit_hash, storage_folder)
def _from_pretrained(
    cls,
    model_id,
    revision,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    **model_kwargs,
):
    """Load a Keras model from a local directory or from a downloaded snapshot.

    Args:
        model_id: Path of an existing local directory, or a repo id that is
            passed to ``snapshot_download`` when no such directory exists.
        revision: Forwarded to ``snapshot_download``.
        cache_dir: Forwarded to ``snapshot_download``.
        force_download: Accepted for interface compatibility; currently unused.
        proxies: Accepted for interface compatibility; currently unused.
        resume_download: Accepted for interface compatibility; currently unused.
        local_files_only: Forwarded to ``snapshot_download``.
        use_auth_token: Forwarded to ``snapshot_download``.
        **model_kwargs: Passed to ``tf.keras.models.load_model``. An optional
            ``config`` entry is removed first and attached to the returned
            model as ``model.config``.

    Returns:
        The object returned by ``tf.keras.models.load_model`` with an extra
        ``config`` attribute (``None`` when no config was supplied).

    Raises:
        ImportError: If ``is_tf_available()`` reports that TensorFlow is missing.

    TODO - force_download, proxies and resume_download are still not used since
    we are calling snapshot_download instead of hf_hub_download.
    """
    if not is_tf_available():
        raise ImportError(
            "Called a Tensorflow-specific function but could not import it."
        )
    import tensorflow as tf

    # TODO - Figure out what to do about these config values. Config is not going to be needed to load model
    cfg = model_kwargs.pop("config", None)

    # Root is either a local filepath matching model_id or a cached snapshot
    if os.path.isdir(model_id):
        storage_folder = model_id
    else:
        # Forward the auth token and the offline flag so that private repos
        # and cache-only loads are honoured instead of silently ignored.
        storage_folder = snapshot_download(
            repo_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )

    model = tf.keras.models.load_model(storage_folder, **model_kwargs)

    # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
    model.config = cfg

    return model
def test_download_model_local_only_multiple(self):
    """Check ``local_files_only`` behaviour when several revisions are cached."""
    repo_id = f"{USER}/{REPO_NAME}"

    # Test `main` branch
    with tempfile.TemporaryDirectory() as tmpdirname:
        # download both from branch and from commit
        snapshot_download(repo_id, cache_dir=tmpdirname)
        snapshot_download(
            repo_id,
            revision=self.first_commit_hash,
            cache_dir=tmpdirname,
        )

        # now load from cache and make sure a warning is raised
        with self.assertWarns(Warning):
            snapshot_download(
                repo_id,
                cache_dir=tmpdirname,
                local_files_only=True,
            )

    # cache multiple commits and make sure correct commit is taken
    with tempfile.TemporaryDirectory() as tmpdirname:
        # first download folder to cache it
        snapshot_download(repo_id, cache_dir=tmpdirname)

        # now load folder from another branch
        snapshot_download(
            repo_id,
            revision="other",
            cache_dir=tmpdirname,
        )

        # now make sure that loading "main" branch gives correct branch
        storage_folder = snapshot_download(
            repo_id,
            cache_dir=tmpdirname,
            local_files_only=True,
        )

        # folder contains three entries, among them dummy_file.txt and
        # the .gitattributes
        folder_contents = os.listdir(storage_folder)
        self.assertEqual(len(folder_contents), 3)
        self.assertIn("dummy_file.txt", folder_contents)
        self.assertIn(".gitattributes", folder_contents)

        with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
            contents = f.read()
        self.assertEqual(contents, "v2")

        # folder name contains the 2nd commit sha and not the 3rd
        self.assertIn(self.second_commit_hash, storage_folder)
def test_download_private_model(self):
    """Check that a private repo needs a token and downloads correctly with one.

    The repo is switched to private for the duration of the test and is
    always switched back to public afterwards, even when a check fails, so
    that other tests do not inherit a private repo.
    """
    repo_id = f"{USER}/{REPO_NAME}"

    def assert_main_snapshot(storage_folder):
        # Shared checks for a snapshot of `main`; must be called while the
        # temporary cache directory still exists.
        # folder contains the two files contributed and the .gitattributes
        folder_contents = os.listdir(storage_folder)
        self.assertEqual(len(folder_contents), 3)
        self.assertIn("dummy_file.txt", folder_contents)
        self.assertIn("dummy_file_2.txt", folder_contents)
        self.assertIn(".gitattributes", folder_contents)

        with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
            contents = f.read()
        self.assertEqual(contents, "v2")

        # folder name contains the revision's commit sha.
        self.assertIn(self.second_commit_hash, storage_folder)

    self._api.update_repo_visibility(
        token=self._token, name=REPO_NAME, private=True
    )
    try:
        # Test download fails without token
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaisesRegex(
                requests.exceptions.HTTPError, "404 Client Error"
            ):
                snapshot_download(repo_id, revision="main", cache_dir=tmpdirname)

        # Test we can download with token from cache
        with tempfile.TemporaryDirectory() as tmpdirname:
            HfFolder.save_token(self._token)
            storage_folder = snapshot_download(
                repo_id,
                revision="main",
                cache_dir=tmpdirname,
                use_auth_token=True,
            )
            assert_main_snapshot(storage_folder)

        # Test we can download with explicit token
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(
                repo_id,
                revision="main",
                cache_dir=tmpdirname,
                use_auth_token=self._token,
            )
            assert_main_snapshot(storage_folder)
    finally:
        # Restore visibility no matter how the checks above ended.
        self._api.update_repo_visibility(
            token=self._token, name=REPO_NAME, private=False
        )