def _get_repo_url_from_name(
    repo_name: str,
    organization: Optional[str] = None,
    private: Optional[bool] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
) -> str:
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
                "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                "token as the `use_auth_token` argument."
            )
    else:
        token = None

    # Special provision for the test endpoint (CI)
    return create_repo(
        token,
        repo_name,
        organization=organization,
        private=private,
        repo_type=None,
        exist_ok=True,
    )
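# Usage sketch (illustrative; the repo name is hypothetical and assumes a
# valid stored login token):
# >>> repo_url = _get_repo_url_from_name("my-model", use_auth_token=True)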
def has_file(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    revision: Optional[str] = None,
    mirror: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
):
    """
    Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.

    <Tip warning={false}>

    This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist
    for this repo, but will return False for regular connection errors.

    </Tip>
    """
    if os.path.isdir(path_or_repo):
        return os.path.isfile(os.path.join(path_or_repo, filename))

    url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)

    headers = {"user-agent": http_user_agent()}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
    try:
        _raise_for_status(r)
        return True
    except RepositoryNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.")
    except RevisionNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
            f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
        )
    except requests.HTTPError:
        # We return False for EntryNotFoundError (logical) as well as any connection error.
        return False
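# Usage sketch (illustrative; the repo id is hypothetical):
# >>> has_file("user/my-model", "config.json")
# True
# >>> has_file("user/my-model", "missing_file.bin")
# False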
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"
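# Usage sketch (illustrative; the organization name is hypothetical):
# >>> get_full_repo_name("my-model", organization="my-org")
# 'my-org/my-model'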
def get_list_of_files(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
) -> List[str]:
    """
    Gets the list of files inside `path_or_repo`.

    Args:
        path_or_repo (`str` or `os.PathLike`):
            Can be either the id of a repo on huggingface.co or a path to a *directory*.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether or not to only rely on local files and not to attempt to download any files.

    <Tip warning={true}>

    This API is not optimized, so calling it a lot may result in connection errors.

    </Tip>

    Returns:
        `List[str]`: The list of files available in `path_or_repo`.
    """
    path_or_repo = str(path_or_repo)
    # If path_or_repo is a folder, we just return what is inside (subdirectories included).
    if os.path.isdir(path_or_repo):
        list_of_files = []
        for path, dir_names, file_names in os.walk(path_or_repo):
            list_of_files.extend([os.path.join(path, f) for f in file_names])
        return list_of_files

    # Can't grab the files if we are in offline mode.
    if is_offline_mode() or local_files_only:
        return []

    # Otherwise we grab the token and use the list_repo_files method.
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token is True:
        token = HfFolder.get_token()
    else:
        token = None

    try:
        return list_repo_files(path_or_repo, revision=revision, token=token)
    except HTTPError as e:
        raise ValueError(
            f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
        ) from e
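# Usage sketch (illustrative; the repo id and file names are hypothetical):
# >>> get_list_of_files("user/my-model")
# ['.gitattributes', 'config.json', 'pytorch_model.bin']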
def push_to_hf_hub(
    model,
    local_dir,
    repo_namespace_or_url=None,
    commit_message='Add model',
    use_auth_token=True,
    git_email=None,
    git_user=None,
    revision=None,
    model_config=None,
):
    if repo_namespace_or_url:
        repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
    else:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()

        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
                "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                "token as the `use_auth_token` argument.")

        repo_owner = HfApi().whoami(token)['name']
        repo_name = Path(local_dir).name

    repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'

    repo = Repository(
        local_dir,
        clone_from=repo_url,
        use_auth_token=use_auth_token,
        git_user=git_user,
        git_email=git_email,
        revision=revision,
    )

    # Prepare a default model card that includes the necessary tags to enable inference.
    readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
    with repo.commit(commit_message):
        # Save model weights and config.
        save_for_hf(model, repo.local_dir, model_config=model_config)

        # Save a model card if it doesn't exist.
        readme_path = Path(repo.local_dir) / 'README.md'
        if not readme_path.exists():
            readme_path.write_text(readme_text)

    return repo.git_remote_url()
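# Usage sketch (illustrative; `my_model` and the local directory are
# hypothetical, and `save_for_hf` must know how to serialize the model):
# >>> url = push_to_hf_hub(my_model, './my-model', commit_message='Add model')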
def move_cache(cache_dir=None, new_cache_dir=None, token=None):
    if new_cache_dir is None:
        new_cache_dir = TRANSFORMERS_CACHE
    if cache_dir is None:
        # Migrate from old cache in .cache/huggingface/hub
        old_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
        if os.path.isdir(str(old_cache)):
            cache_dir = str(old_cache)
        else:
            cache_dir = new_cache_dir
    if token is None:
        token = HfFolder.get_token()
    cached_files = get_all_cached_files(cache_dir=cache_dir)
    print(f"Moving {len(cached_files)} files to the new cache system")

    hub_metadata = {}
    for file_info in tqdm(cached_files):
        url = file_info.pop("url")
        if url not in hub_metadata:
            try:
                hub_metadata[url] = get_hub_metadata(url, token=token)
            except requests.HTTPError:
                continue
        etag, commit_hash = hub_metadata[url]

        if etag is None or commit_hash is None:
            continue

        if file_info["etag"] != etag:
            # Cached file is not up to date, we just throw it away as a new version will be downloaded anyway.
            clean_files_for(os.path.join(cache_dir, file_info["file"]))
            continue

        url_info = extract_info_from_url(url)
        if url_info is None:
            # Not a file from huggingface.co
            continue

        repo = os.path.join(new_cache_dir, url_info["repo"])
        move_to_new_cache(
            file=os.path.join(cache_dir, file_info["file"]),
            repo=repo,
            filename=url_info["filename"],
            revision=url_info["revision"],
            etag=etag,
            commit_hash=commit_hash,
        )
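# Usage sketch (illustrative): migrate files found in the legacy cache into
# the layout under TRANSFORMERS_CACHE, reusing the stored login token.
# >>> move_cache()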
def get_hub_metadata(url, token=None):
    """
    Returns the etag and associated commit hash for a given url.
    """
    if token is None:
        token = HfFolder.get_token()
    headers = {"user-agent": http_user_agent()}
    headers["authorization"] = f"Bearer {token}"

    r = huggingface_hub.file_download._request_with_retry(
        method="HEAD", url=url, headers=headers, allow_redirects=False
    )
    huggingface_hub.file_download._raise_for_status(r)
    commit_hash = r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
    etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
    if etag is not None:
        etag = huggingface_hub.file_download._normalize_etag(etag)
    return etag, commit_hash
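# Usage sketch (illustrative; the URL is hypothetical):
# >>> etag, commit_hash = get_hub_metadata(
# ...     "https://huggingface.co/user/my-model/resolve/main/config.json"
# ... )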
def login_to_hub() -> None:
    """Login to huggingface hub"""
    access_token = HfFolder.get_token()
    if access_token is not None and HfApi()._is_valid_token(access_token):
        logging.info("Huggingface Hub token found and valid")
        HfApi().set_access_token(access_token)
    else:
        subprocess.call(["huggingface-cli", "login"])
        HfApi().set_access_token(HfFolder().get_token())

    # Check if git lfs is installed.
    try:
        subprocess.call(["git", "lfs", "version"])
    except FileNotFoundError:
        raise OSError(
            "Looks like you do not have git-lfs installed, please install. "
            "You can install from https://git-lfs.github.com/. "
            "Then run `git lfs install` (you only have to do this once)."
        )
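# Usage sketch (illustrative): reuse a stored token when it is valid,
# otherwise fall back to an interactive `huggingface-cli login`.
# >>> login_to_hub()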
def get_list_of_files(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
) -> List[str]:
    """
    Gets the list of files inside :obj:`path_or_repo`.

    Args:
        path_or_repo (:obj:`str` or :obj:`os.PathLike`):
            Can be either the id of a repo on huggingface.co or a path to a `directory`.
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).

    Returns:
        :obj:`List[str]`: The list of files available in :obj:`path_or_repo`.
    """
    path_or_repo = str(path_or_repo)
    # If path_or_repo is a folder, we just return what is inside (subdirectories included).
    if os.path.isdir(path_or_repo):
        list_of_files = []
        for path, dir_names, file_names in os.walk(path_or_repo):
            list_of_files.extend([os.path.join(path, f) for f in file_names])
        return list_of_files

    # Can't grab the files if we are in offline mode.
    if is_offline_mode():
        return []

    # Otherwise we grab the token and use the model_info method.
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token is True:
        token = HfFolder.get_token()
    else:
        token = None
    model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info(
        path_or_repo, revision=revision, token=token
    )
    return [f.rfilename for f in model_info.siblings]
def _create_repo(
    self,
    repo_id: str,
    private: Optional[bool] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    repo_url: Optional[str] = None,
    organization: Optional[str] = None,
):
    """
    Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url` and `organization`, retrieves
    the token.
    """
    if repo_url is not None:
        warnings.warn(
            "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
            "instead."
        )
        repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
    if organization is not None:
        warnings.warn(
            "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
            "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
        )
        if not repo_id.startswith(organization):
            if "/" in repo_id:
                repo_id = repo_id.split("/")[-1]
            repo_id = f"{organization}/{repo_id}"

    token = HfFolder.get_token() if use_auth_token is True else use_auth_token
    url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)

    # If the namespace is not there, add it or `upload_file` will complain
    if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
        repo_id = get_full_repo_name(repo_id, token=token)

    return repo_id, token
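# Usage sketch (illustrative; `obj` stands for whatever mixin instance this
# method is bound to, and the repo id is hypothetical):
# >>> repo_id, token = obj._create_repo("my-model", private=True, use_auth_token=True)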
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            _raise_for_status(r)
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (
            requests.exceptions.SSLError,
            requests.exceptions.ProxyError,
            RepositoryNotFoundError,
            EntryNotFoundError,
            RevisionNotFoundError,
        ):
            # Actually raise for those subclasses of ConnectionError
            # Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
            raise
        except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Otherwise, our Internet connection is down.
            # etag is None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise FileNotFoundError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")

            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
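# Usage sketch (illustrative; the URL is hypothetical):
# >>> path = get_from_cache(
# ...     "https://huggingface.co/user/my-model/resolve/main/config.json",
# ...     resume_download=True,
# ... )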
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
):
    """
    Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
    Transformers module.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
              namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if
            they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.
    """
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
        submodule = "local"
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )
    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        shutil.copy(resolved_module_file, submodule_path / module_file)
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
    else:
        # Get the commit hash
        # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
        if isinstance(use_auth_token, str):
            token = use_auth_token
        elif use_auth_token is True:
            token = HfFolder.get_token()
        else:
            token = None

        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
        # Make sure we also have every file with relative imports
        for module_needed in modules_needed:
            if not (submodule_path / module_needed).exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
    return os.path.join(full_submodule, module_file)
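# Usage sketch (illustrative; the repo id and module file are hypothetical):
# >>> module_path = get_cached_module_file("user/custom-bert", "modeling_custom_bert.py")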
def save_to_hub(self,
                repo_name: str,
                organization: Optional[str] = None,
                private: Optional[bool] = None,
                commit_message: str = "Add new SentenceTransformer model.",
                local_model_path: Optional[str] = None,
                exist_ok: bool = False,
                replace_model_card: bool = False):
    """
    Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

    :param repo_name: Repository name for your model in the Hub.
    :param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization).
    :param private: Set to true, for hosting a private model
    :param commit_message: Message to commit while pushing.
    :param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
    :param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
    :param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card
    :return: The url of the commit of your model in the given repository.
    """
    token = HfFolder.get_token()
    if token is None:
        raise ValueError("You must login to the Hugging Face hub on this computer by typing `transformers-cli login`.")

    if '/' in repo_name:
        splits = repo_name.split('/', maxsplit=1)
        if organization is None or organization == splits[0]:
            organization = splits[0]
            repo_name = splits[1]
        else:
            raise ValueError("You passed an invalid repository name: {}.".format(repo_name))

    endpoint = "https://huggingface.co"
    repo_url = HfApi(endpoint=endpoint).create_repo(
        token,
        repo_name,
        organization=organization,
        private=private,
        repo_type=None,
        exist_ok=exist_ok,
    )
    full_model_name = repo_url[len(endpoint) + 1:].strip("/")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # First create the repo (and clone its content if it's nonempty).
        logging.info("Create repository and clone it if it exists")
        repo = Repository(tmp_dir, clone_from=repo_url)

        # If user provides local files, copy them.
        if local_model_path:
            copy_tree(local_model_path, tmp_dir)
        else:
            # Else, save model directly into local repo.
            create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md'))
            self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card)

        # Find files larger than 5MB and track them with git-lfs.
        large_files = []
        for root, dirs, files in os.walk(tmp_dir):
            for filename in files:
                file_path = os.path.join(root, filename)
                rel_path = os.path.relpath(file_path, tmp_dir)

                if os.path.getsize(file_path) > (5 * 1024 * 1024):
                    large_files.append(rel_path)

        if len(large_files) > 0:
            logging.info("Track files with git lfs: {}".format(", ".join(large_files)))
            repo.lfs_track(large_files)

        logging.info("Push model to the hub. This might take a while")
        push_return = repo.push_to_hub(commit_message=commit_message)

        def on_rm_error(func, path, exc_info):
            # path contains the path of the file that couldn't be removed
            # let's just assume that it's read-only and unlink it.
            try:
                os.chmod(path, stat.S_IWRITE)
                os.unlink(path)
            except Exception:
                pass

        # Remove the .git folder. On Windows, the .git folder might be read-only and cannot be deleted.
        # Hence, try to set write permissions on error.
        try:
            for f in os.listdir(tmp_dir):
                shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error)
        except Exception as e:
            logging.warning("Error when deleting temp folder: {}".format(str(e)))

    return push_return
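# Usage sketch (illustrative; `model` is a loaded SentenceTransformer and the
# repo/organization names are hypothetical):
# >>> model.save_to_hub("my-sentence-model", organization="my-org", private=True)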
def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:
    """Save model and its configuration on HF hub

    >>> from doctr.models import login_to_hub, push_to_hf_hub
    >>> from doctr.models.recognition import crnn_mobilenet_v3_small
    >>> login_to_hub()
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

    Args:
        model: TF or PyTorch model to be saved
        model_name: name of the model which is also the repository name
        task: task name
        **kwargs: keyword arguments for push_to_hf_hub
    """
    run_config = kwargs.get("run_config", None)
    arch = kwargs.get("arch", None)

    if run_config is None and arch is None:
        raise ValueError("run_config or arch must be specified")
    if task not in ["classification", "detection", "recognition", "obj_detection"]:
        raise ValueError("task must be one of classification, detection, recognition, obj_detection")

    # default readme
    readme = textwrap.dedent(
        f"""
    ---
    language: en
    ---

    <p align="center">
    <img src="https://github.com/mindee/doctr/releases/download/v0.3.1/Logo_doctr.gif" width="60%">
    </p>

    **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**

    ## Task: {task}

    https://github.com/mindee/doctr

    ### Example usage:

    ```python
    >>> from doctr.io import DocumentFile
    >>> from doctr.models import ocr_predictor, from_hub

    >>> img = DocumentFile.from_images(['<image_path>'])
    >>> # Load your model from the hub
    >>> model = from_hub('mindee/my-model')

    >>> # Pass it to the predictor
    >>> # If your model is a recognition model:
    >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
    >>>                           reco_arch=model,
    >>>                           pretrained=True)

    >>> # If your model is a detection model:
    >>> predictor = ocr_predictor(det_arch=model,
    >>>                           reco_arch='crnn_mobilenet_v3_small',
    >>>                           pretrained=True)

    >>> # Get your predictions
    >>> res = predictor(img)
    ```
    """
    )

    # add run configuration to readme if available
    if run_config is not None:
        arch = run_config.arch
        readme += textwrap.dedent(
            f"""### Run Configuration
            \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
        )

    if arch not in AVAILABLE_ARCHS[task]:  # type: ignore
        raise ValueError(
            f"Architecture: {arch} for task: {task} not found.\nAvailable architectures: {AVAILABLE_ARCHS}"
        )

    commit_message = f"Add {model_name} model"

    local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
    repo_url = HfApi().create_repo(model_name, token=HfFolder.get_token(), exist_ok=False)
    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)

    with repo.commit(commit_message):
        _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
        readme_path = Path(repo.local_dir) / "README.md"
        readme_path.write_text(readme)

    repo.git_push()
def push_to_hf(
    repo_name: str,
    serialization_dir: Optional[Union[str, PathLike]] = None,
    archive_path: Optional[Union[str, PathLike]] = None,
    organization: Optional[str] = None,
    commit_message: str = "Update repository",
    local_repo_path: Union[str, PathLike] = "hub",
    use_auth_token: Union[bool, str] = True,
) -> str:
    """Pushes model and related files to the Hugging Face Hub ([hf.co](https://hf.co/))

    # Parameters

    repo_name: `str`
        Name of the repository in the Hugging Face Hub.

    serialization_dir : `Union[str, PathLike]`, optional (default = `None`)
        Full path to a directory with the serialized model.

    archive_path : `Union[str, PathLike]`, optional (default = `None`)
        Full path to the zipped model (e.g. model/model.tar.gz). Use `serialization_dir` if possible.

    organization : `Optional[str]`, optional (default = `None`)
        Name of organization to which the model should be uploaded.

    commit_message: `str` (default=`Update repository`)
        Commit message to use for the push.

    local_repo_path : `Union[str, Path]`, optional (default=`hub`)
        Local directory where the repository will be saved.

    use_auth_token (``str`` or ``bool``, `optional`, defaults ``True``):
        huggingface_token can be extracted from ``HfApi().login(username, password)`` and is used to authenticate
        against the Hugging Face Hub (useful from Google Colab for instance). It's automatically retrieved if you've
        done `huggingface-cli login` before.
    """
    if serialization_dir is not None:
        working_dir = Path(serialization_dir)
        if archive_path is not None:
            raise ValueError("serialization_dir and archive_path are mutually exclusive, please just use one.")
        if not working_dir.exists() or not working_dir.is_dir():
            raise ValueError(
                f"Can't find path: {serialization_dir}, please point "
                "to a directory with the serialized model."
            )
    elif archive_path is not None:
        working_dir = Path(archive_path)
        if (
            not working_dir.exists()
            or not zipfile.is_zipfile(working_dir)
            and not tarfile.is_tarfile(working_dir)
        ):
            raise ValueError(
                f"Can't find path: {archive_path}, please point to a .tar.gz archive "
                "or to a directory with the serialized model."
            )
        else:
            logging.info(
                "Using the archive_path is discouraged. Using the serialization_dir "
                "will also upload metrics and TensorBoard traces to the Hugging Face Hub."
            )
    else:
        raise ValueError("please specify either serialization_dir or archive_path")

    if isinstance(use_auth_token, str):
        huggingface_token = use_auth_token
    elif use_auth_token:
        huggingface_token = HfFolder.get_token()

    # Create the repo (or clone its content if it's nonempty)
    api = HfApi()
    repo_url = api.create_repo(
        name=repo_name,
        token=huggingface_token,
        organization=organization,
        private=False,
        exist_ok=True,
    )

    repo_local_path = Path(local_repo_path) / repo_name
    repo = Repository(repo_local_path, clone_from=repo_url, use_auth_token=use_auth_token)
    repo.git_pull(rebase=True)

    # Model file should be tracked with Git LFS
    repo.lfs_track(["*.th"])
    info_msg = f"Preparing repository '{repo_name}'"
    if organization is not None:
        info_msg += f" ({organization})"
    logging.info(info_msg)

    # Extract information from either serializable directory or a
    # .tar.gz file
    if serialization_dir is not None:
        for filename in working_dir.iterdir():
            _copy_allowed_file(Path(filename), repo_local_path)
    else:
        with tempfile.TemporaryDirectory() as temp_dir:
            extracted_dir = Path(cached_path(working_dir, temp_dir, extract_archive=True))
            for filename in extracted_dir.iterdir():
                _copy_allowed_file(Path(filename), repo_local_path)

    _create_model_card(repo_local_path)

    logging.info(f"Pushing repo {repo_name} to the Hugging Face Hub")
    repo.push_to_hub(commit_message=commit_message)

    logging.info(f"View your model in {repo_url}")
    return repo_url
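# Usage sketch (illustrative; the serialization directory and repo name are
# hypothetical):
# >>> push_to_hf(repo_name="my-allennlp-model", serialization_dir="runs/my_run")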
def push_to_huggingface_hub(self):
    """Creates a downstream repository on the Hub and pushes training artifacts to it."""
    if self.args.hf_hub_org.lower() != "none":
        organization = self.args.hf_hub_org
    else:
        organization = os.environ.get("HF_USERNAME")
    huggingface_token = HfFolder.get_token()
    print(f"[Runner] - Organisation to push fine-tuned model to: {organization}")

    # Extract upstream repository metadata
    if self.args.hub == "huggingface":
        model_info = HfApi().model_info(self.args.upstream, token=huggingface_token)
        downstream_model_id = model_info.sha
        # Exclude "/" characters from downstream repo ID
        upstream_model_id = model_info.modelId.replace("/", "__")
    else:
        upstream_model_id = self.args.upstream.replace("/", "__")
        downstream_model_id = str(uuid.uuid4())[:8]

    repo_name = f"{upstream_model_id}__{downstream_model_id}"
    # Create downstream repo on the Hub
    repo_url = HfApi().create_repo(
        token=huggingface_token,
        name=repo_name,
        organization=organization,
        exist_ok=True,
        private=False,
    )
    print(f"[Runner] - Created Hub repo: {repo_url}")

    # Download repo
    HF_HUB_DIR = "hf_hub"
    REPO_ROOT_DIR = os.path.join(self.args.expdir, HF_HUB_DIR, repo_name)
    REPO_TASK_DIR = os.path.join(REPO_ROOT_DIR, self.args.downstream, self.args.expname)
    print(f"[Runner] - Cloning Hub repo to {REPO_ROOT_DIR}")
    model_repo = Repository(
        local_dir=REPO_ROOT_DIR, clone_from=repo_url, use_auth_token=huggingface_token
    )
    # Pull latest changes if they exist
    model_repo.git_pull()

    # Copy checkpoints, tensorboard logs, and args / configs
    # Note that this copies all files from the experiment directory,
    # including those from multiple runs
    shutil.copytree(self.args.expdir, REPO_TASK_DIR, dirs_exist_ok=True,
                    ignore=shutil.ignore_patterns(HF_HUB_DIR))

    # By default we use model.ckpt in the PreTrainedModel interface, so
    # rename the best checkpoint to match this convention
    checkpoints = list(Path(REPO_TASK_DIR).glob("*best*.ckpt"))
    if len(checkpoints) == 0:
        print("[Runner] - Did not find a best checkpoint! Using the final checkpoint instead ...")
        CKPT_PATH = os.path.join(REPO_TASK_DIR, f"states-{self.config['runner']['total_steps']}.ckpt")
    elif len(checkpoints) > 1:
        print(f"[Runner] - More than one best checkpoint found! Using {checkpoints[0]} as default ...")
        CKPT_PATH = checkpoints[0]
    else:
        print(f"[Runner] - Found best checkpoint {checkpoints[0]}!")
        CKPT_PATH = checkpoints[0]

    shutil.move(CKPT_PATH, os.path.join(REPO_TASK_DIR, "model.ckpt"))
    model_repo.lfs_track("*.ckpt")

    # Write model card
    self._create_model_card(REPO_ROOT_DIR)

    # Push everything to the Hub
    print("[Runner] - Pushing model files to the Hub ...")
    model_repo.push_to_hub()
    print("[Runner] - Training run complete!")
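# Usage sketch (illustrative; `runner` stands for the object this method is
# bound to, with `args`, `config`, and `_create_model_card` already set up):
# >>> runner.push_to_huggingface_hub()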