예제 #1
0
def create_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    private: Optional[bool] = None,
    repo_type: Optional[str] = None,
    exist_ok: Optional[bool] = False,
    space_sdk: Optional[str] = None,
) -> str:
    """
    Create a Hub repository, calling `hf_api.create_repo` with the signature the installed
    huggingface_hub version expects.

    The huggingface_hub.HfApi.create_repo parameters changed in 0.5.0 and some of them were deprecated.
    Before 0.5.0 the repository is identified by `name` + `organization`; from 0.5.0 on it is
    identified by a single `repo_id` of the form `"<namespace>/<name>"` (or just `"<name>"`).

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization name.
            By default (`None`) it uses the namespace associated to the token used; in that case
            only the bare `name` is sent as `repo_id`.
        private (`bool`, *optional*):
            Whether the model repo should be private.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if uploading to a dataset or
            space, `None` or `"model"` if uploading to a model. Default is
            `None`.
        exist_ok (`bool`, *optional*, defaults to `False`):
            If `True`, do not raise an error if repo already exists.
        space_sdk (`str`, *optional*):
            Choice of SDK to use if repo_type is "space". Can be
            "streamlit", "gradio", or "static".

    Returns:
        `str`: URL of the repo, as returned by `hf_api.create_repo`.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.create_repo(
            name=name,
            organization=organization,
            token=token,
            private=private,
            repo_type=repo_type,
            exist_ok=exist_ok,
            space_sdk=space_sdk,
        )

    # The `organization` parameter is deprecated in huggingface_hub>=0.5.0: the namespace is
    # encoded in `repo_id` instead. Only prefix it when one was given -- formatting `None`
    # would produce the literal "None/<name>" rather than using the token's namespace.
    repo_id = name if organization is None else f"{organization}/{name}"
    return hf_api.create_repo(
        repo_id=repo_id,
        token=token,
        private=private,
        repo_type=repo_type,
        exist_ok=exist_ok,
        space_sdk=space_sdk,
    )
예제 #2
0
def push_to_hf(
    repo_name: str,
    serialization_dir: Optional[Union[str, PathLike]] = None,
    archive_path: Optional[Union[str, PathLike]] = None,
    organization: Optional[str] = None,
    commit_message: str = "Update repository",
    local_repo_path: Union[str, PathLike] = "hub",
    use_auth_token: Union[bool, str] = True,
) -> str:
    """Pushes model and related files to the Hugging Face Hub ([hf.co](https://hf.co/))

    Exactly one of `serialization_dir` or `archive_path` must be given.

    # Parameters

    repo_name: `str`
        Name of the repository in the Hugging Face Hub.

    serialization_dir : `Union[str, PathLike]`, optional (default = `None`)
        Full path to a directory with the serialized model.

    archive_path : `Union[str, PathLike]`, optional (default = `None`)
        Full path to the zipped model (e.g. model/model.tar.gz). Use `serialization_dir` if possible.

    organization : `Optional[str]`, optional (default = `None`)
        Name of organization to which the model should be uploaded.

    commit_message: `str` (default=`Update repository`)
        Commit message to use for the push.

    local_repo_path : `Union[str, PathLike]`, optional (default=`hub`)
        Local directory where the repository will be saved.

    use_auth_token : `Union[bool, str]`, optional (default = `True`)
        A string is used directly as the Hugging Face token. `True` loads the token saved by
        `huggingface-cli login` (via `HfFolder.get_token()`); `False` uses no token.
        A token can also be obtained from ``HfApi().login(username, password)`` (useful from
        Google Colab for instance).

    # Returns

    `str`
        The repository URL returned by `HfApi.create_repo`.

    # Raises

    `ValueError`
        If neither or both of `serialization_dir` / `archive_path` are given, if
        `serialization_dir` is not an existing directory, or if `archive_path` is not an
        existing zip or tar file.
    """
    if serialization_dir is not None:
        if archive_path is not None:
            raise ValueError(
                "serialization_dir and archive_path are mutually exclusive, please just use one."
            )
        working_dir = Path(serialization_dir)
        # `is_dir()` is False for missing paths too, so it covers the existence check.
        if not working_dir.is_dir():
            raise ValueError(
                f"Can't find path: {serialization_dir}, please point "
                "to a directory with the serialized model."
            )
    elif archive_path is not None:
        working_dir = Path(archive_path)
        # Check `is_file()` first: besides covering "missing", it keeps directories away from
        # `tarfile.is_tarfile`, which can raise instead of returning False for them.
        if not working_dir.is_file() or not (
            zipfile.is_zipfile(working_dir) or tarfile.is_tarfile(working_dir)
        ):
            raise ValueError(
                f"Can't find path: {archive_path}, please point to a .tar.gz archive "
                "or to a directory with the serialized model."
            )
        logging.info(
            "Using the archive_path is discouraged. Using the serialization_dir "
            "will also upload metrics and TensorBoard traces to the Hugging Face Hub."
        )
    else:
        raise ValueError("please specify either serialization_dir or archive_path")

    info_msg = f"Preparing repository '{repo_name}'"
    if organization is not None:
        info_msg += f" ({organization})"
    logging.info(info_msg)

    # Always bind the token: previously `use_auth_token=False` left it unassigned and the
    # create_repo call below failed with UnboundLocalError.
    huggingface_token: Optional[str] = None
    if isinstance(use_auth_token, str):
        huggingface_token = use_auth_token
    elif use_auth_token:
        huggingface_token = HfFolder.get_token()

    # Create the repo (or clone its content if it's nonempty)
    api = HfApi()
    # NOTE(review): `name=`/`organization=` is the pre-0.5.0 huggingface_hub signature; newer
    # releases identify the repo by `repo_id` -- confirm against the pinned huggingface_hub version.
    repo_url = api.create_repo(
        name=repo_name,
        token=huggingface_token,
        organization=organization,
        private=False,
        exist_ok=True,
    )

    repo_local_path = Path(local_repo_path) / repo_name
    repo = Repository(repo_local_path,
                      clone_from=repo_url,
                      use_auth_token=use_auth_token)
    repo.git_pull(rebase=True)

    # Model file should be tracked with Git LFS
    repo.lfs_track(["*.th"])

    # Extract information from either serializable directory or a
    # .tar.gz file
    if serialization_dir is not None:
        for filename in working_dir.iterdir():
            _copy_allowed_file(Path(filename), repo_local_path)
    else:
        with tempfile.TemporaryDirectory() as temp_dir:
            extracted_dir = Path(
                cached_path(working_dir, temp_dir, extract_archive=True))
            for filename in extracted_dir.iterdir():
                _copy_allowed_file(Path(filename), repo_local_path)

    _create_model_card(repo_local_path)

    logging.info(f"Pushing repo {repo_name} to the Hugging Face Hub")
    repo.push_to_hub(commit_message=commit_message)

    logging.info(f"View your model in {repo_url}")
    return repo_url