def snapshot_download(repo_id: str,
                      revision: Optional[str] = None,
                      cache_dir: Union[str, Path, None] = None,
                      library_name: Optional[str] = None,
                      library_version: Optional[str] = None,
                      user_agent: Union[Dict, str, None] = None,
                      ignore_files: Optional[List[str]] = None) -> str:
    """
    Method derived from huggingface_hub.
    Adds a new parameter 'ignore_files', which allows ignoring certain files / file patterns.
    """
    if cache_dir is None:
        cache_dir = HUGGINGFACE_HUB_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    _api = HfApi()
    model_info = _api.model_info(repo_id=repo_id, revision=revision)

    storage_folder = os.path.join(
        cache_dir, repo_id.replace("/", REPO_ID_SEPARATOR) + "." + model_info.sha)

    for model_file in model_info.siblings:
        if ignore_files is not None:
            skip_download = False
            for pattern in ignore_files:
                if fnmatch.fnmatch(model_file.rfilename, pattern):
                    skip_download = True
                    break
            if skip_download:
                continue

        url = hf_hub_url(repo_id, filename=model_file.rfilename, revision=model_info.sha)
        relative_filepath = os.path.join(*model_file.rfilename.split("/"))

        # Create potential nested dir
        nested_dirname = os.path.dirname(
            os.path.join(storage_folder, relative_filepath))
        os.makedirs(nested_dirname, exist_ok=True)

        path = cached_download(
            url,
            cache_dir=storage_folder,
            force_filename=relative_filepath,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
        )

        if os.path.exists(path + ".lock"):
            os.remove(path + ".lock")

    return storage_folder
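A minimal usage sketch of this helper, assuming it is importable from the surrounding module; the repo id and ignore patterns below are illustrative:

# Hypothetical usage: snapshot a model repo while skipping TensorFlow/Flax weight files.
folder = snapshot_download(
    "sentence-transformers/all-MiniLM-L6-v2",
    ignore_files=["*.h5", "*.msgpack", "tf_model*"],
)
print(folder)  # local cache folder containing the remaining files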
def create_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    private: Optional[bool] = None,
    repo_type: Optional[str] = None,
    exist_ok: Optional[bool] = False,
    space_sdk: Optional[str] = None,
) -> str:
    """
    The huggingface_hub.HfApi.create_repo parameters changed in 0.5.0 and some of them were
    deprecated. This function checks the huggingface_hub version to call the right parameters.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization
            name. By default it uses the namespace associated to the token used.
        private (`bool`, *optional*): Whether the model repo should be private.
        repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or
            space, `None` or `"model"` if uploading to a model. Default is `None`.
        exist_ok (`bool`, *optional*, defaults to `False`): If `True`, do not raise an error if repo already exists.
        space_sdk (`str`, *optional*): Choice of SDK to use if repo_type is "space". Can be "streamlit", "gradio", or "static".

    Returns:
        `str`: URL to the newly created repo.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.create_repo(
            name=name,
            organization=organization,
            token=token,
            private=private,
            repo_type=repo_type,
            exist_ok=exist_ok,
            space_sdk=space_sdk,
        )
    else:
        # the `organization` parameter is deprecated in huggingface_hub>=0.5.0
        return hf_api.create_repo(
            repo_id=f"{organization}/{name}",
            token=token,
            private=private,
            repo_type=repo_type,
            exist_ok=exist_ok,
            space_sdk=space_sdk,
        )
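A short usage sketch for this compatibility wrapper; the repo and organization names are hypothetical:

# Behaves the same on huggingface_hub < 0.5.0 and >= 0.5.0.
api = HfApi()
url = create_repo(api, name="my-demo-model", organization="my-org", private=True, exist_ok=True)
print(url)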
def setUpClass(cls):
    """
    Share this valid token in all tests below.
    """
    cls._api = HfApi(endpoint=ENDPOINT_STAGING)
    cls._token = TOKEN
    cls._api.set_access_token(TOKEN)
def test_push_to_hub_model_kwargs(self):
    REPO_NAME = repo_name("PUSH_TO_HUB")
    model = self.model_init()
    model = self.model_fit(model)
    push_to_hub_keras(
        model,
        repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
        api_endpoint=ENDPOINT_STAGING,
        use_auth_token=self._token,
        git_user="******",
        git_email="*****@*****.**",
        config={"num": 7, "act": "gelu_fast"},
        include_optimizer=True,
        save_traces=False,
    )

    model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
        f"{USER}/{REPO_NAME}",
    )
    self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}")

    from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")
    self.assertRaises(ValueError, msg="Exception encountered when calling layer*")

    self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token)
def get_type(model_id):
    info = HfApi().model_info(repo_id=model_id)
    if info.config:
        if "speechbrain" in info.config:
            return ModelType(info.config["speechbrain"]["interface"].upper())
        else:
            raise ValueError("speechbrain not in config.json")
    raise ValueError("no config.json in repository")
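A hedged usage sketch; the model id is illustrative and the repo must declare a `speechbrain` section in its config.json for the lookup to succeed:

# Hypothetical usage: read the declared SpeechBrain interface type of a Hub repo.
model_type = get_type("speechbrain/asr-crdnn-commonvoice-fr")
print(model_type)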
def upload(self, hf_username, model_name, token):
    if token is None:
        token = getpass("Enter your HuggingFace access token")
    if Path("./model").exists():
        shutil.rmtree("./model")
    HfApi.set_access_token(token)
    model_url = HfApi().create_repo(token=token, name=model_name, exist_ok=True)
    model_repo = Repository(
        "./model",
        clone_from=model_url,
        use_auth_token=token,
        git_email=f"{hf_username}@users.noreply.huggingface.co",
        git_user=hf_username,
    )
    readme_txt = f"""
---
language: "en"
thumbnail: "Keywords to Sentences"
tags:
- keytotext
- k2t
- Keywords to Sentences
model-index:
- name: {model_name}
---

Idea is to build a model which will take keywords as inputs and generate sentences as outputs.

Potential use case can include:
- Marketing
- Search Engine Optimization
- Topic generation etc.
- Fine tuning of topic modeling models
""".strip()

    (Path(model_repo.local_dir) / "README.md").write_text(readme_txt)
    self.save_model()
    commit_url = model_repo.push_to_hub()

    print("Check out your model at:")
    print(f"https://huggingface.co/{hf_username}/{model_name}")
def login_to_hub() -> None:
    """Login to huggingface hub"""
    access_token = HfFolder.get_token()
    if access_token is not None and HfApi()._is_valid_token(access_token):
        logging.info("Huggingface Hub token found and valid")
        HfApi().set_access_token(access_token)
    else:
        subprocess.call(["huggingface-cli", "login"])
        HfApi().set_access_token(HfFolder().get_token())

    # check if git lfs is installed
    try:
        subprocess.call(["git", "lfs", "version"])
    except FileNotFoundError:
        raise OSError(
            "Looks like you do not have git-lfs installed, please install. "
            "You can install from https://git-lfs.github.com/. "
            "Then run `git lfs install` (you only have to do this once).")
def __post_init__(self):
    # Infer default license from the checkpoint used, if possible.
    if self.license is None and not is_offline_mode() and self.finetuned_from is not None:
        try:
            model_info = HfApi().model_info(self.finetuned_from)
            for tag in model_info.tags:
                if tag.startswith("license:"):
                    self.license = tag[8:]
        except requests.exceptions.HTTPError:
            pass
def list_adapters(source: str = None, model_name: str = None) -> List[AdapterInfo]:
    """
    Retrieves a list of all publicly available adapters on AdapterHub.ml or on huggingface.co.

    Args:
        source (str, optional): Identifier of the source(s) from where to get adapters. Can be either:

            - "ah": search on AdapterHub.ml.
            - "hf": search on HuggingFace model hub (huggingface.co).
            - None (default): search on all sources

        model_name (str, optional): If specified, only returns adapters trained for the model with this identifier.
    """
    adapters = []
    if source == "ah" or source is None:
        try:
            all_ah_adapters_file = download_cached(ADAPTER_HUB_ALL_FILE)
        except requests.exceptions.HTTPError:
            raise EnvironmentError(
                "Unable to load list of adapters from AdapterHub.ml. The service might be temporarily unavailable."
            )
        with open(all_ah_adapters_file, "r") as f:
            all_ah_adapters_data = json.load(f)
        adapters += [AdapterInfo(**info) for info in all_ah_adapters_data]
    if source == "hf" or source is None:
        if "fetch_config" in inspect.signature(HfApi.list_models).parameters:
            kwargs = {"full": True, "fetch_config": True}
        else:
            logger.warning(
                "Using old version of huggingface-hub package for fetching. Please upgrade to latest version for accurate results."
            )
            kwargs = {"full": True}
        all_hf_adapters_data = HfApi().list_models(filter="adapter-transformers", **kwargs)
        for model_info in all_hf_adapters_data:
            adapter_info = AdapterInfo(
                source="hf",
                adapter_id=model_info.modelId,
                model_name=model_info.config.get("adapter_transformers", {}).get("model_name")
                if model_info.config else None,
                username=model_info.modelId.split("/")[0],
                sha1_checksum=model_info.sha,
            )
            adapters.append(adapter_info)

    if model_name is not None:
        adapters = [adapter for adapter in adapters if adapter.model_name == model_name]

    return adapters
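A brief usage sketch, assuming the returned `AdapterInfo` entries expose the fields populated above; the base model id is illustrative:

# Hypothetical usage: list adapters trained for a given base model, from both sources.
for info in list_adapters(model_name="bert-base-uncased")[:5]:
    print(info.source, info.adapter_id)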
def push_to_hf_hub(
    model,
    local_dir,
    repo_namespace_or_url=None,
    commit_message='Add model',
    use_auth_token=True,
    git_email=None,
    git_user=None,
    revision=None,
    model_config=None,
):
    if repo_namespace_or_url:
        repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
    else:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()

        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
                "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                "token as the `use_auth_token` argument.")

        repo_owner = HfApi().whoami(token)['name']
        repo_name = Path(local_dir).name

    repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'

    repo = Repository(
        local_dir,
        clone_from=repo_url,
        use_auth_token=use_auth_token,
        git_user=git_user,
        git_email=git_email,
        revision=revision,
    )

    # Prepare a default model card that includes the necessary tags to enable inference.
    readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'

    with repo.commit(commit_message):
        # Save model weights and config.
        save_for_hf(model, repo.local_dir, model_config=model_config)

        # Save a model card if it doesn't exist.
        readme_path = Path(repo.local_dir) / 'README.md'
        if not readme_path.exists():
            readme_path.write_text(readme_text)

    return repo.git_remote_url()
def delete_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> str:
    """
    The huggingface_hub.HfApi.delete_repo parameters changed in 0.5.0 and some of them were
    deprecated. This function checks the huggingface_hub version to call the right parameters.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization
            name. By default it uses the namespace associated to the token used.
        repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if uploading to a dataset or
            space, `None` or `"model"` if uploading to a model. Default is `None`.

    Returns:
        `str`: URL to the newly created repo.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.delete_repo(
            name=name,
            organization=organization,
            token=token,
            repo_type=repo_type,
        )
    else:
        # the `organization` parameter is deprecated in huggingface_hub>=0.5.0
        return hf_api.delete_repo(
            repo_id=f"{organization}/{name}",
            token=token,
            repo_type=repo_type,
        )
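A hedged sketch pairing this wrapper with the matching `create_repo` wrapper above; the repo and organization names are hypothetical:

api = HfApi()
create_repo(api, name="tmp-compat-check", organization="my-org", exist_ok=True)
delete_repo(api, name="tmp-compat-check", organization="my-org")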
def __init__(self,
             dataset=None,
             device='cpu',
             extract='tokens',
             randomize=True,
             remove_stopwords=True,
             lemmatizer=None):
    self.dataset = dataset
    self.device = device
    self.extract = extract
    self.randomize = randomize
    self.remove_stopwords = remove_stopwords
    self.model = None
    self.tokenizer = None
    self.interpreter = None
    self.lemmatizer = WordNetLemmatizer().lemmatize if lemmatizer is None else lemmatizer

    if dataset is not None:
        api = HfApi()
        # find huggingface model to provide rationalized output
        modelIds = api.list_models(filter=("pytorch", "dataset:" + dataset, "sibyl"))
        if modelIds:
            modelId = getattr(modelIds[0], 'modelId')
            print('Using ' + modelId + ' to rationalize keyphrase selections.')
            self.model = AutoModelForSequenceClassification.from_pretrained(modelId).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(modelId)
            self.interpreter = SequenceClassificationExplainer(self.model, self.tokenizer)
        else:
            self.extract = "concepts"

    self.stops = stopwords.words('english') if self.remove_stopwords else []
def test_push_to_hub_model_card_plot_false(self):
    REPO_NAME = repo_name("PUSH_TO_HUB")
    model = self.model_init()
    model = self.model_fit(model)
    push_to_hub_keras(
        model,
        repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
        api_endpoint=ENDPOINT_STAGING,
        use_auth_token=self._token,
        git_user="******",
        git_email="*****@*****.**",
        plot_model=False,
    )

    model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
        f"{USER}/{REPO_NAME}",
    )
    self.assertFalse("model.png" in [f.rfilename for f in model_info.siblings])

    self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token)
def get_list_of_files(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
) -> List[str]:
    """
    Gets the list of files inside :obj:`path_or_repo`.

    Args:
        path_or_repo (:obj:`str` or :obj:`os.PathLike`):
            Can be either the id of a repo on huggingface.co or a path to a `directory`.
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).

    Returns:
        :obj:`List[str]`: The list of files available in :obj:`path_or_repo`.
    """
    path_or_repo = str(path_or_repo)
    # If path_or_repo is a folder, we just return what is inside (subdirectories included).
    if os.path.isdir(path_or_repo):
        list_of_files = []
        for path, dir_names, file_names in os.walk(path_or_repo):
            list_of_files.extend([os.path.join(path, f) for f in file_names])
        return list_of_files

    # Can't grab the files if we are on offline mode.
    if is_offline_mode():
        return []

    # Otherwise we grab the token and use the model_info method.
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token is True:
        token = HfFolder.get_token()
    else:
        token = None
    model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info(
        path_or_repo, revision=revision, token=token)
    return [f.rfilename for f in model_info.siblings]
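A usage sketch covering both branches (local directory vs. Hub repo id); the local path and repo id are illustrative:

# Hypothetical usage: works for a local folder ...
print(get_list_of_files("./my_local_checkpoint"))
# ... and for a repo id on huggingface.co
print(get_list_of_files("bert-base-uncased", revision="main"))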
def resolve(model_id: str) -> Tuple[str, str]:  # assumes `from typing import Tuple`
    try:
        info = HfApi().model_info(model_id)
    except Exception as e:
        raise ValueError(
            f"The hub has no information on {model_id}, does it exist: {e}")
    try:
        task = info.pipeline_tag
    except Exception:
        raise ValueError(
            f"The hub has no `pipeline_tag` on {model_id}, you can set it in the `README.md` yaml header")
    try:
        framework = info.library_name
    except Exception:
        raise ValueError(
            f"The hub has no `library_name` on {model_id}, you can set it in the `README.md` yaml header")
    return task, framework.replace("-", "_")
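A hedged usage sketch; the model id is illustrative and the printed values depend on the repo's metadata:

task, framework = resolve("distilbert-base-uncased-finetuned-sst-2-english")
print(task, framework)  # e.g. "text-classification transformers"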
def test_push_to_hub(self):
    REPO_NAME = repo_name("PUSH_TO_HUB")
    model = self.model_init()
    model.build((None, 2))
    push_to_hub_keras(
        model,
        repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
        api_endpoint=ENDPOINT_STAGING,
        use_auth_token=self._token,
        git_user="******",
        git_email="*****@*****.**",
        config={"num": 7, "act": "gelu_fast"},
    )

    model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
        f"{USER}/{REPO_NAME}",
    )
    self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}")

    self._api.delete_repo(name=f"{REPO_NAME}", token=self._token)
def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInfo]:
    """
    Retrieves information about a specific adapter.

    Args:
        adapter_id (str): The identifier of the adapter to retrieve.
        source (str, optional): Identifier of the source(s) from where to get adapters. Can be either:

            - "ah": search on AdapterHub.ml.
            - "hf": search on HuggingFace model hub (huggingface.co).

    Returns:
        AdapterInfo: The adapter information or None if the adapter was not found.
    """
    if source == "ah":
        if adapter_id.startswith("@"):
            adapter_id = adapter_id[1:]
        try:
            data = http_get_json(f"/adapters/{adapter_id}.json")
            return AdapterInfo(**data["info"])
        except EnvironmentError:
            return None
    elif source == "hf":
        try:
            model_info = HfApi().model_info(adapter_id)
            return AdapterInfo(
                source="hf",
                adapter_id=model_info.modelId,
                model_name=model_info.config.get("adapter_transformers", {}).get("model_name")
                if model_info.config else None,
                username=model_info.modelId.split("/")[0],
                sha1_checksum=model_info.sha,
            )
        except requests.exceptions.HTTPError:
            return None
    else:
        raise ValueError("Please specify either 'ah' or 'hf' as source.")
def test_override_tensorboard(self):
    REPO_NAME = repo_name("PUSH_TO_HUB")
    with tempfile.TemporaryDirectory() as tmpdirname:
        os.makedirs(f"{tmpdirname}/tb_log_dir")
        with open(f"{tmpdirname}/tb_log_dir/tensorboard.txt", "w") as fp:
            fp.write("Keras FTW")
        model = self.model_init()
        model.build((None, 2))
        push_to_hub_keras(
            model,
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
            log_dir=f"{tmpdirname}/tb_log_dir",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="******",
            git_email="*****@*****.**",
        )
        os.makedirs(f"{tmpdirname}/tb_log_dir2")
        with open(f"{tmpdirname}/tb_log_dir2/override.txt", "w") as fp:
            fp.write("Keras FTW")
        push_to_hub_keras(
            model,
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
            log_dir=f"{tmpdirname}/tb_log_dir2",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="******",
            git_email="*****@*****.**",
        )

        model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
            f"{USER}/{REPO_NAME}",
        )
        self.assertTrue("logs/override.txt" in [f.rfilename for f in model_info.siblings])
        self.assertFalse("logs/tensorboard.txt" in [f.rfilename for f in model_info.siblings])

        self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token)
def test_full_deserialization_hub(self):
    # Check we can read this file.
    # This used to fail because of BufReader that would fail because the
    # file exceeds the buffer capacity
    api = HfApi()
    not_loadable = []
    invalid_pre_tokenizer = []

    # models = api.list_models(filter="transformers")
    # for model in tqdm.tqdm(models):
    #     model_id = model.modelId
    #     for model_file in model.siblings:
    #         filename = model_file.rfilename
    #         if filename == "tokenizer.json":
    #             all_models.append((model_id, filename))

    all_models = [("HueyNemud/das22-10-camembert_pretrained", "tokenizer.json")]

    for (model_id, filename) in tqdm.tqdm(all_models):
        tokenizer_file = cached_download(hf_hub_url(model_id, filename=filename))

        is_ok = check(tokenizer_file)
        if not is_ok:
            print(f"{model_id} is affected by no type")
            invalid_pre_tokenizer.append(model_id)
        try:
            Tokenizer.from_file(tokenizer_file)
        except Exception as e:
            print(f"{model_id} is not loadable: {e}")
            not_loadable.append(model_id)
        except:
            print(f"{model_id} is not loadable: Rust error")
            not_loadable.append(model_id)

    self.assertEqual(invalid_pre_tokenizer, [])
    self.assertEqual(not_loadable, [])
def save_to_hub(self,
                repo_name: str,
                organization: Optional[str] = None,
                private: Optional[bool] = None,
                commit_message: str = "Add new SentenceTransformer model.",
                local_model_path: Optional[str] = None,
                exist_ok: bool = False,
                replace_model_card: bool = False):
    """
    Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

    :param repo_name: Repository name for your model in the Hub.
    :param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization).
    :param private: Set to true, for hosting a private model
    :param commit_message: Message to commit while pushing.
    :param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
    :param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
    :param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card
    :return: The url of the commit of your model in the given repository.
    """
    token = HfFolder.get_token()
    if token is None:
        raise ValueError(
            "You must login to the Hugging Face hub on this computer by typing `transformers-cli login`."
        )

    if '/' in repo_name:
        splits = repo_name.split('/', maxsplit=1)
        if organization is None or organization == splits[0]:
            organization = splits[0]
            repo_name = splits[1]
        else:
            raise ValueError(
                "You passed an invalid repository name: {}.".format(repo_name))

    endpoint = "https://huggingface.co"
    repo_url = HfApi(endpoint=endpoint).create_repo(
        token,
        repo_name,
        organization=organization,
        private=private,
        repo_type=None,
        exist_ok=exist_ok,
    )
    full_model_name = repo_url[len(endpoint) + 1:].strip("/")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # First create the repo (and clone its content if it's nonempty).
        logging.info("Create repository and clone it if it exists")
        repo = Repository(tmp_dir, clone_from=repo_url)

        # If user provides local files, copy them.
        if local_model_path:
            copy_tree(local_model_path, tmp_dir)
        else:
            # Else, save model directly into local repo.
            create_model_card = replace_model_card or not os.path.exists(
                os.path.join(tmp_dir, 'README.md'))
            self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card)

        # Find files larger than 5MB and track them with git-lfs
        large_files = []
        for root, dirs, files in os.walk(tmp_dir):
            for filename in files:
                file_path = os.path.join(root, filename)
                rel_path = os.path.relpath(file_path, tmp_dir)

                if os.path.getsize(file_path) > (5 * 1024 * 1024):
                    large_files.append(rel_path)

        if len(large_files) > 0:
            logging.info("Track files with git lfs: {}".format(", ".join(large_files)))
            repo.lfs_track(large_files)

        logging.info("Push model to the hub. This might take a while")
        push_return = repo.push_to_hub(commit_message=commit_message)

        def on_rm_error(func, path, exc_info):
            # path contains the path of the file that couldn't be removed
            # let's just assume that it's read-only and unlink it.
            try:
                os.chmod(path, stat.S_IWRITE)
                os.unlink(path)
            except:
                pass

        # Remove .git folder. On Windows, the .git folder might be read-only and cannot be deleted
        # Hence, try to set write permissions on error
        try:
            for f in os.listdir(tmp_dir):
                shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error)
        except Exception as e:
            logging.warning("Error when deleting temp folder: {}".format(str(e)))
            pass

    return push_return
class TestPushToHub(TestCase): _api = HfApi(endpoint=ENDPOINT_STAGING) @classmethod def setUpClass(cls): """ Share this valid token in all tests below. """ cls._hf_folder_patch = patch( "huggingface_hub.hf_api.HfFolder.path_token", TOKEN_PATH_STAGING, ) cls._hf_folder_patch.start() cls._token = cls._api.login(username=USER, password=PASS) HfFolder.save_token(cls._token) @classmethod def tearDownClass(cls) -> None: HfFolder.delete_token() cls._hf_folder_patch.stop() def test_push_dataset_dict_to_hub_no_token(self): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) local_ds = DatasetDict({"train": ds}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) # Ensure that there is a single file on the repository that has the correct name files = sorted( self._api.list_repo_files(ds_name, repo_type="dataset")) self.assertListEqual(files, [ ".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json" ]) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], repo_type="dataset") def test_push_dataset_dict_to_hub_name_without_namespace(self): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) local_ds = DatasetDict({"train": ds}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) # Ensure that there is a single file on the repository that has the correct name files = sorted( self._api.list_repo_files(ds_name, repo_type="dataset")) self.assertListEqual(files, [ ".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json" ]) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], repo_type="dataset") def test_push_dataset_dict_to_hub_private(self): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) local_ds = DatasetDict({"train": ds}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name, token=self._token, private=True) hub_ds = load_dataset(ds_name, download_mode="force_redownload", use_auth_token=self._token) self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) # Ensure that there is a single file on the repository that has the correct name files = sorted( self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) self.assertListEqual(files, [ ".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json" ]) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") def test_push_dataset_dict_to_hub(self): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) local_ds = DatasetDict({"train": ds}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: 
local_ds.push_to_hub(ds_name, token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) # Ensure that there is a single file on the repository that has the correct name files = sorted( self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) self.assertListEqual(files, [ ".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json" ]) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") def test_push_dataset_dict_to_hub_multiple_files(self): ds = Dataset.from_dict({ "x": list(range(1000)), "y": list(range(1000)) }) local_ds = DatasetDict({"train": ds}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name, token=self._token, shard_size=500 << 5) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) # Ensure that there are two files on the repository that have the correct name files = sorted( self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) self.assertListEqual( files, [ ".gitattributes", "data/train-00000-of-00002.parquet", "data/train-00001-of-00002.parquet", "dataset_infos.json", ], ) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") def test_push_dataset_dict_to_hub_overwrite_files(self): ds = Dataset.from_dict({ "x": list(range(1000)), "y": list(range(1000)) }) ds2 = Dataset.from_dict({"x": list(range(100)), "y": list(range(100))}) local_ds = DatasetDict({"train": ds, "random": ds2}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" # Push to hub two times, but the second time with a larger amount of files. # Verify that the new files contain the correct dataset. try: local_ds.push_to_hub(ds_name, token=self._token) with tempfile.TemporaryDirectory() as tmp: # Add a file starting with "data" to ensure it doesn't get deleted. 
path = Path(tmp) / "datafile.txt" with open(path, "w") as f: f.write("Bogus file") self._api.upload_file(str(path), path_in_repo="datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token) local_ds.push_to_hub(ds_name, token=self._token, shard_size=500 << 5) # Ensure that there are two files on the repository that have the correct name files = sorted( self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) self.assertListEqual( files, [ ".gitattributes", "data/random-00000-of-00001.parquet", "data/train-00000-of-00002.parquet", "data/train-00001-of-00002.parquet", "datafile.txt", "dataset_infos.json", ], ) self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") # Push to hub two times, but the second time with fewer files. # Verify that the new files contain the correct dataset and that non-necessary files have been deleted. try: local_ds.push_to_hub(ds_name, token=self._token, shard_size=500 << 5) with tempfile.TemporaryDirectory() as tmp: # Add a file starting with "data" to ensure it doesn't get deleted. path = Path(tmp) / "datafile.txt" with open(path, "w") as f: f.write("Bogus file") self._api.upload_file(str(path), path_in_repo="datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token) local_ds.push_to_hub(ds_name, token=self._token) # Ensure that there are two files on the repository that have the correct name files = sorted( self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) self.assertListEqual( files, [ ".gitattributes", "data/random-00000-of-00001.parquet", "data/train-00000-of-00001.parquet", "datafile.txt", "dataset_infos.json", ], ) # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") def test_push_dataset_to_hub(self): local_ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name, split="train", token=self._token) local_ds_dict = {"train": local_ds} hub_ds_dict = load_dataset(ds_name, download_mode="force_redownload") self.assertListEqual(list(local_ds_dict.keys()), list(hub_ds_dict.keys())) for ds_split_name in local_ds_dict.keys(): local_ds = local_ds_dict[ds_split_name] hub_ds = hub_ds_dict[ds_split_name] self.assertListEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds.features.keys()), list(hub_ds.features.keys())) self.assertDictEqual(local_ds.features, hub_ds.features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], 
token=self._token, repo_type="dataset") def test_push_dataset_to_hub_custom_features(self): features = Features({ "x": Value("int64"), "y": ClassLabel(names=["neg", "pos"]) }) ds = Dataset.from_dict({ "x": [1, 2, 3], "y": [0, 0, 1] }, features=features) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: ds.push_to_hub(ds_name, token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertListEqual(ds.column_names, hub_ds["train"].column_names) self.assertListEqual(list(ds.features.keys()), list(hub_ds["train"].features.keys())) self.assertDictEqual(ds.features, hub_ds["train"].features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") def test_push_dataset_dict_to_hub_custom_features(self): features = Features({ "x": Value("int64"), "y": ClassLabel(names=["neg", "pos"]) }) ds = Dataset.from_dict({ "x": [1, 2, 3], "y": [0, 0, 1] }, features=features) local_ds = DatasetDict({"test": ds}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name, token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["test"].features.keys()), list(hub_ds["test"].features.keys())) self.assertDictEqual(local_ds["test"].features, hub_ds["test"].features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") def test_push_dataset_to_hub_custom_splits(self): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: ds.push_to_hub(ds_name, split="random", token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertListEqual(ds.column_names, hub_ds["random"].column_names) self.assertListEqual(list(ds.features.keys()), list(hub_ds["random"].features.keys())) self.assertDictEqual(ds.features, hub_ds["random"].features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") def test_push_dataset_dict_to_hub_custom_splits(self): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) local_ds = DatasetDict({"random": ds}) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name, token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["random"].features.keys()), list(hub_ds["random"].features.keys())) self.assertDictEqual(local_ds["random"].features, hub_ds["random"].features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset") @unittest.skip( "This test cannot pass until iterable datasets have push to hub") def test_push_streaming_dataset_dict_to_hub(self): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) local_ds = DatasetDict({"train": ds}) with tempfile.TemporaryDirectory() as tmp: local_ds.save_to_disk(tmp) local_ds = load_dataset(tmp, streaming=True) ds_name = f"{USER}/test-{int(time.time() * 10e3)}" try: local_ds.push_to_hub(ds_name, token=self._token) hub_ds = load_dataset(ds_name, download_mode="force_redownload") self.assertDictEqual(local_ds.column_names, hub_ds.column_names) self.assertListEqual(list(local_ds["train"].features.keys()), 
list(hub_ds["train"].features.keys())) self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features) finally: self._api.delete_repo(ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset")
class TestPushToHub(AllenNlpTestCase): def setup_method(self): super().setup_method() self.api = HfApi(ENDPOINT_STAGING) self.token = self.api.login(username=USER, password=PASS) self.local_repo_path = self.TEST_DIR / "hub" self.clone_path = self.TEST_DIR / "hub_clone" def teardown_method(self): super().teardown_method() try: self.api.delete_repo(token=self.token, name=REPO_NAME) except requests.exceptions.HTTPError: pass try: self.api.delete_repo( token=self.token, organization=ORG_NAME, name=REPO_NAME, ) except requests.exceptions.HTTPError: pass @with_staging_testing def test_push_to_hub_archive_path(self): archive_path = self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz" url = push_to_hf( repo_name=REPO_NAME, archive_path=archive_path, local_repo_path=self.local_repo_path, use_auth_token=self.token, ) # Check that the returned commit url # actually exists. r = requests.head(url) r.raise_for_status() Repository( self.clone_path, clone_from=f"{ENDPOINT_STAGING}/{USER}/{REPO_NAME}", use_auth_token=self.token, ) assert "model.th" in os.listdir(self.clone_path) shutil.rmtree(self.clone_path) @with_staging_testing def test_push_to_hub_serialization_dir(self): serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization" url = push_to_hf( repo_name=REPO_NAME, serialization_dir=serialization_dir, local_repo_path=self.local_repo_path, use_auth_token=self.token, ) # Check that the returned commit url # actually exists. r = requests.head(url) r.raise_for_status() Repository( self.clone_path, clone_from=f"{ENDPOINT_STAGING}/{USER}/{REPO_NAME}", use_auth_token=self.token, ) assert "model.th" in os.listdir(self.clone_path) shutil.rmtree(self.clone_path) @with_staging_testing def test_push_to_hub_to_org(self): serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization" url = push_to_hf( repo_name=REPO_NAME, serialization_dir=serialization_dir, organization=ORG_NAME, local_repo_path=self.local_repo_path, use_auth_token=self.token, ) # Check that the returned commit url # actually exists. r = requests.head(url) r.raise_for_status() Repository( self.clone_path, clone_from=f"{ENDPOINT_STAGING}/{ORG_NAME}/{REPO_NAME}", use_auth_token=self.token, ) assert "model.th" in os.listdir(self.clone_path) shutil.rmtree(self.clone_path) @with_staging_testing def test_push_to_hub_fails_with_invalid_token(self): serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization" with pytest.raises(requests.exceptions.HTTPError): push_to_hf( repo_name=REPO_NAME, serialization_dir=serialization_dir, local_repo_path=self.local_repo_path, use_auth_token="invalid token", )
def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: """Save model and its configuration on HF hub >>> from doctr.models import login_to_hub, push_to_hf_hub >>> from doctr.models.recognition import crnn_mobilenet_v3_small >>> login_to_hub() >>> model = crnn_mobilenet_v3_small(pretrained=True) >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small') Args: model: TF or PyTorch model to be saved model_name: name of the model which is also the repository name task: task name **kwargs: keyword arguments for push_to_hf_hub """ run_config = kwargs.get("run_config", None) arch = kwargs.get("arch", None) if run_config is None and arch is None: raise ValueError("run_config or arch must be specified") if task not in [ "classification", "detection", "recognition", "obj_detection" ]: raise ValueError( "task must be one of classification, detection, recognition, obj_detection" ) # default readme readme = textwrap.dedent(f""" --- language: en --- <p align="center"> <img src="https://github.com/mindee/doctr/releases/download/v0.3.1/Logo_doctr.gif" width="60%"> </p> **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch** ## Task: {task} https://github.com/mindee/doctr ### Example usage: ```python >>> from doctr.io import DocumentFile >>> from doctr.models import ocr_predictor, from_hub >>> img = DocumentFile.from_images(['<image_path>']) >>> # Load your model from the hub >>> model = from_hub('mindee/my-model') >>> # Pass it to the predictor >>> # If your model is a recognition model: >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large', >>> reco_arch=model, >>> pretrained=True) >>> # If your model is a detection model: >>> predictor = ocr_predictor(det_arch=model, >>> reco_arch='crnn_mobilenet_v3_small', >>> pretrained=True) >>> # Get your predictions >>> res = predictor(img) ``` """) # add run configuration to readme if available if run_config is not None: arch = run_config.arch readme += textwrap.dedent(f"""### Run Configuration \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}""" ) if arch not in AVAILABLE_ARCHS[task]: # type: ignore raise ValueError(f"Architecture: {arch} for task: {task} not found.\ \nAvailable architectures: {AVAILABLE_ARCHS}") commit_message = f"Add {model_name} model" local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name) repo_url = HfApi().create_repo(model_name, token=HfFolder.get_token(), exist_ok=False) repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True) with repo.commit(commit_message): _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task) readme_path = Path(repo.local_dir) / "README.md" readme_path.write_text(readme) repo.git_push()
def setup_method(self):
    super().setup_method()
    self.api = HfApi(ENDPOINT_STAGING)
    self.token = self.api.login(username=USER, password=PASS)
    self.local_repo_path = self.TEST_DIR / "hub"
    self.clone_path = self.TEST_DIR / "hub_clone"
def push_to_hf(
    repo_name: str,
    serialization_dir: Optional[Union[str, PathLike]] = None,
    archive_path: Optional[Union[str, PathLike]] = None,
    organization: Optional[str] = None,
    commit_message: str = "Update repository",
    local_repo_path: Union[str, PathLike] = "hub",
    use_auth_token: Union[bool, str] = True,
) -> str:
    """Pushes model and related files to the Hugging Face Hub ([hf.co](https://hf.co/))

    # Parameters

    repo_name: `str`
        Name of the repository in the Hugging Face Hub.

    serialization_dir : `Union[str, PathLike]`, optional (default = `None`)
        Full path to a directory with the serialized model.

    archive_path : `Union[str, PathLike]`, optional (default = `None`)
        Full path to the zipped model (e.g. model/model.tar.gz). Use `serialization_dir` if possible.

    organization : `Optional[str]`, optional (default = `None`)
        Name of the organization to which the model should be uploaded.

    commit_message: `str` (default=`Update repository`)
        Commit message to use for the push.

    local_repo_path : `Union[str, Path]`, optional (default=`hub`)
        Local directory where the repository will be saved.

    use_auth_token (``str`` or ``bool``, `optional`, defaults ``True``):
        huggingface_token can be extracted from ``HfApi().login(username, password)`` and is used to authenticate
        against the Hugging Face Hub (useful from Google Colab for instance). It's automatically retrieved
        if you've done `huggingface-cli login` before.
    """
    if serialization_dir is not None:
        working_dir = Path(serialization_dir)

        if archive_path is not None:
            raise ValueError(
                "serialization_dir and archive_path are mutually exclusive, please just use one."
            )

        if not working_dir.exists() or not working_dir.is_dir():
            raise ValueError(
                f"Can't find path: {serialization_dir}, please point "
                "to a directory with the serialized model.")
    elif archive_path is not None:
        working_dir = Path(archive_path)
        if (not working_dir.exists() or not zipfile.is_zipfile(working_dir)
                and not tarfile.is_tarfile(working_dir)):
            raise ValueError(
                f"Can't find path: {archive_path}, please point to a .tar.gz archive "
                "or to a directory with the serialized model.")
        else:
            logging.info(
                "Using the archive_path is discouraged. Using the serialization_dir "
                "will also upload metrics and TensorBoard traces to the Hugging Face Hub.")
    else:
        raise ValueError("please specify either serialization_dir or archive_path")

    info_msg = f"Preparing repository '{repo_name}'"
    if isinstance(use_auth_token, str):
        huggingface_token = use_auth_token
    elif use_auth_token:
        huggingface_token = HfFolder.get_token()

    # Create the repo (or clone its content if it's nonempty)
    api = HfApi()
    repo_url = api.create_repo(
        name=repo_name,
        token=huggingface_token,
        organization=organization,
        private=False,
        exist_ok=True,
    )

    repo_local_path = Path(local_repo_path) / repo_name
    repo = Repository(repo_local_path, clone_from=repo_url, use_auth_token=use_auth_token)
    repo.git_pull(rebase=True)

    # Model file should be tracked with Git LFS
    repo.lfs_track(["*.th"])
    info_msg = f"Preparing repository '{repo_name}'"
    if organization is not None:
        info_msg += f" ({organization})"
    logging.info(info_msg)

    # Extract information from either the serialization directory or a
    # .tar.gz file
    if serialization_dir is not None:
        for filename in working_dir.iterdir():
            _copy_allowed_file(Path(filename), repo_local_path)
    else:
        with tempfile.TemporaryDirectory() as temp_dir:
            extracted_dir = Path(cached_path(working_dir, temp_dir, extract_archive=True))
            for filename in extracted_dir.iterdir():
                _copy_allowed_file(Path(filename), repo_local_path)

    _create_model_card(repo_local_path)

    logging.info(f"Pushing repo {repo_name} to the Hugging Face Hub")
    repo.push_to_hub(commit_message=commit_message)

    logging.info(f"View your model in {repo_url}")
    return repo_url
class HubMixingCommonTest(unittest.TestCase):
    _api = HfApi(endpoint=ENDPOINT_STAGING)
def push_to_huggingface_hub(self): """Creates a downstream repository on the Hub and pushes training artifacts to it.""" if self.args.hf_hub_org.lower() != "none": organization = self.args.hf_hub_org else: organization = os.environ.get("HF_USERNAME") huggingface_token = HfFolder.get_token() print(f"[Runner] - Organisation to push fine-tuned model to: {organization}") # Extract upstream repository metadata if self.args.hub == "huggingface": model_info = HfApi().model_info(self.args.upstream, token=huggingface_token) downstream_model_id = model_info.sha # Exclude "/" characters from downstream repo ID upstream_model_id = model_info.modelId.replace("/", "__") else: upstream_model_id = self.args.upstream.replace("/", "__") downstream_model_id = str(uuid.uuid4())[:8] repo_name = f"{upstream_model_id}__{downstream_model_id}" # Create downstream repo on the Hub repo_url = HfApi().create_repo( token=huggingface_token, name=repo_name, organization=organization, exist_ok=True, private=False, ) print(f"[Runner] - Created Hub repo: {repo_url}") # Download repo HF_HUB_DIR = "hf_hub" REPO_ROOT_DIR = os.path.join(self.args.expdir, HF_HUB_DIR, repo_name) REPO_TASK_DIR = os.path.join(REPO_ROOT_DIR, self.args.downstream, self.args.expname) print(f"[Runner] - Cloning Hub repo to {REPO_ROOT_DIR}") model_repo = Repository( local_dir=REPO_ROOT_DIR, clone_from=repo_url, use_auth_token=huggingface_token ) # Pull latest changes if they exist model_repo.git_pull() # Copy checkpoints, tensorboard logs, and args / configs # Note that this copies all files from the experiment directory, # including those from multiple runs shutil.copytree(self.args.expdir, REPO_TASK_DIR, dirs_exist_ok=True, ignore=shutil.ignore_patterns(HF_HUB_DIR)) # By default we use model.ckpt in the PreTrainedModel interface, so # rename the best checkpoint to match this convention checkpoints = list(Path(REPO_TASK_DIR).glob("*best*.ckpt")) if len(checkpoints) == 0: print("[Runner] - Did not find a best checkpoint! Using the final checkpoint instead ...") CKPT_PATH = ( os.path.join(REPO_TASK_DIR, f"states-{self.config['runner']['total_steps']}.ckpt") ) elif len(checkpoints) > 1: print(f"[Runner] - More than one best checkpoint found! Using {checkpoints[0]} as default ...") CKPT_PATH = checkpoints[0] else: print(f"[Runner] - Found best checkpoint {checkpoints[0]}!") CKPT_PATH = checkpoints[0] shutil.move(CKPT_PATH, os.path.join(REPO_TASK_DIR, "model.ckpt")) model_repo.lfs_track("*.ckpt") # Write model card self._create_model_card(REPO_ROOT_DIR) # Push everything to the Hub print("[Runner] - Pushing model files to the Hub ...") model_repo.push_to_hub() print("[Runner] - Training run complete!")
class SnapshotDownloadTests(unittest.TestCase): _api = HfApi(endpoint=ENDPOINT_STAGING) @classmethod def setUpClass(cls): """ Share this valid token in all tests below. """ cls._token = TOKEN cls._api.set_access_token(TOKEN) @retry_endpoint def setUp(self) -> None: if os.path.exists(REPO_NAME): shutil.rmtree(REPO_NAME, onerror=set_write_permission_and_retry) logger.info(f"Does {REPO_NAME} exist: {os.path.exists(REPO_NAME)}") repo = Repository( REPO_NAME, clone_from=f"{USER}/{REPO_NAME}", use_auth_token=self._token, git_user="******", git_email="*****@*****.**", ) with repo.commit("Add file to main branch"): with open("dummy_file.txt", "w+") as f: f.write("v1") self.first_commit_hash = repo.git_head_hash() with repo.commit("Add file to main branch"): with open("dummy_file.txt", "w+") as f: f.write("v2") with open("dummy_file_2.txt", "w+") as f: f.write("v3") self.second_commit_hash = repo.git_head_hash() with repo.commit("Add file to other branch", branch="other"): with open("dummy_file_2.txt", "w+") as f: f.write("v4") self.third_commit_hash = repo.git_head_hash() def tearDown(self) -> None: self._api.delete_repo(repo_id=REPO_NAME, token=self._token) shutil.rmtree(REPO_NAME) def test_download_model(self): # Test `main` branch with tempfile.TemporaryDirectory() as tmpdirname: storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue("dummy_file_2.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the revision's commit sha. self.assertTrue(self.second_commit_hash in storage_folder) # Test with specific revision with tempfile.TemporaryDirectory() as tmpdirname: storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", revision=self.first_commit_hash, cache_dir=tmpdirname, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 2) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v1") # folder name contains the revision's commit sha. 
self.assertTrue(self.first_commit_hash in storage_folder) def test_download_private_model(self): self._api.update_repo_visibility( token=self._token, repo_id=REPO_NAME, private=True ) # Test download fails without token with tempfile.TemporaryDirectory() as tmpdirname: with self.assertRaisesRegex( requests.exceptions.HTTPError, "404 Client Error" ): _ = snapshot_download( f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname ) # Test we can download with token from cache with tempfile.TemporaryDirectory() as tmpdirname: HfFolder.save_token(self._token) storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname, use_auth_token=True, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue("dummy_file_2.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the revision's commit sha. self.assertTrue(self.second_commit_hash in storage_folder) # Test we can download with explicit token with tempfile.TemporaryDirectory() as tmpdirname: storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname, use_auth_token=self._token, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue("dummy_file_2.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the revision's commit sha. self.assertTrue(self.second_commit_hash in storage_folder) self._api.update_repo_visibility( token=self._token, repo_id=REPO_NAME, private=False ) def test_download_model_local_only(self): # Test no branch specified with tempfile.TemporaryDirectory() as tmpdirname: # first download folder to cache it snapshot_download(f"{USER}/{REPO_NAME}", cache_dir=tmpdirname) # now load from cache storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", cache_dir=tmpdirname, local_files_only=True, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue("dummy_file_2.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the revision's commit sha. 
self.assertTrue(self.second_commit_hash in storage_folder) # Test with specific revision branch with tempfile.TemporaryDirectory() as tmpdirname: # first download folder to cache it snapshot_download( f"{USER}/{REPO_NAME}", revision="other", cache_dir=tmpdirname, ) # now load from cache storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", revision="other", cache_dir=tmpdirname, local_files_only=True, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the revision's commit sha. self.assertTrue(self.third_commit_hash in storage_folder) # Test with specific revision hash with tempfile.TemporaryDirectory() as tmpdirname: # first download folder to cache it snapshot_download( f"{USER}/{REPO_NAME}", revision=self.first_commit_hash, cache_dir=tmpdirname, ) # now load from cache storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", revision=self.first_commit_hash, cache_dir=tmpdirname, local_files_only=True, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 2) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v1") # folder name contains the revision's commit sha. self.assertTrue(self.first_commit_hash in storage_folder) def test_download_model_local_only_multiple(self): # Test `main` branch with tempfile.TemporaryDirectory() as tmpdirname: # download both from branch and from commit snapshot_download( f"{USER}/{REPO_NAME}", cache_dir=tmpdirname, ) snapshot_download( f"{USER}/{REPO_NAME}", revision=self.first_commit_hash, cache_dir=tmpdirname, ) # now load from cache and make sure warning to be raised with self.assertWarns(Warning): snapshot_download( f"{USER}/{REPO_NAME}", cache_dir=tmpdirname, local_files_only=True, ) # cache multiple commits and make sure correct commit is taken with tempfile.TemporaryDirectory() as tmpdirname: # first download folder to cache it snapshot_download( f"{USER}/{REPO_NAME}", cache_dir=tmpdirname, ) # now load folder from another branch snapshot_download( f"{USER}/{REPO_NAME}", revision="other", cache_dir=tmpdirname, ) # now make sure that loading "main" branch gives correct branch storage_folder = snapshot_download( f"{USER}/{REPO_NAME}", cache_dir=tmpdirname, local_files_only=True, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 3) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue(".gitattributes" in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the 2nd commit sha and not the 3rd self.assertTrue(self.second_commit_hash in storage_folder) def check_download_model_with_regex(self, regex, allow=True): # Test `main` branch allow_regex = regex if allow else None ignore_regex = regex if not allow else None with tempfile.TemporaryDirectory() as tmpdirname: storage_folder = snapshot_download( 
f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname, allow_regex=allow_regex, ignore_regex=ignore_regex, ) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) self.assertEqual(len(folder_contents), 2) self.assertTrue("dummy_file.txt" in folder_contents) self.assertTrue("dummy_file_2.txt" in folder_contents) self.assertTrue(".gitattributes" not in folder_contents) with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: contents = f.read() self.assertEqual(contents, "v2") # folder name contains the revision's commit sha. self.assertTrue(self.second_commit_hash in storage_folder) def test_download_model_with_allow_regex(self): self.check_download_model_with_regex("*.txt") def test_download_model_with_allow_regex_list(self): self.check_download_model_with_regex(["dummy_file.txt", "dummy_file_2.txt"]) def test_download_model_with_ignore_regex(self): self.check_download_model_with_regex(".gitattributes", allow=False) def test_download_model_with_ignore_regex_list(self): self.check_download_model_with_regex(["*.git*", "*.pt"], allow=False)
def setUpClass(cls):
    cls._api = HfApi(endpoint=ENDPOINT_STAGING)
    cls._token = cls._api.login(username=USER, password=PASS)
def main(): torch.multiprocessing.set_sharing_strategy('file_system') torchaudio.set_audio_backend('sox_io') hack_isinstance() # get config and arguments args, config, backup_files = get_downstream_args() if args.cache_dir is not None: torch.hub.set_dir(args.cache_dir) # When torch.distributed.launch is used if args.local_rank is not None: torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(args.backend) if args.mode == 'train' and args.past_exp: ckpt = torch.load(args.init_ckpt, map_location='cpu') now_use_ddp = is_initialized() original_use_ddp = ckpt['Args'].local_rank is not None assert now_use_ddp == original_use_ddp, f'{now_use_ddp} != {original_use_ddp}' if now_use_ddp: now_world = get_world_size() original_world = ckpt['WorldSize'] assert now_world == original_world, f'{now_world} != {original_world}' if args.hub == "huggingface": args.from_hf_hub = True # Setup auth hf_user = os.environ.get("HF_USERNAME") hf_password = os.environ.get("HF_PASSWORD") huggingface_token = HfApi().login(username=hf_user, password=hf_password) HfFolder.save_token(huggingface_token) print(f"Logged into Hugging Face Hub with user: {hf_user}") # Save command if is_leader_process(): with open(os.path.join(args.expdir, f'args_{get_time_tag()}.yaml'), 'w') as file: yaml.dump(vars(args), file) with open(os.path.join(args.expdir, f'config_{get_time_tag()}.yaml'), 'w') as file: yaml.dump(config, file) for file in backup_files: backup(file, args.expdir) # Fix seed and make backends deterministic random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) if args.disable_cudnn: torch.backends.cudnn.enabled = False else: torch.backends.cudnn.enabled = True torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False runner = Runner(args, config) eval(f'runner.{args.mode}')()