예제 #1
0
def delete_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> str:
    """
    Delete a repository through `hf_api`, whatever the installed huggingface_hub version.

    The huggingface_hub.HfApi.delete_repo parameters changed in 0.5.0 and some of them were deprecated.
    This function checks the huggingface_hub version to call the right parameters.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization name.
            By default it uses the namespace associated to the token used.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if deleting a dataset or
            space, `None` or `"model"` if deleting a model. Default is
            `None`.

    Returns:
        Whatever `hf_api.delete_repo` returns, passed through unchanged.
        NOTE(review): the `str` annotation is kept for backward compatibility;
        the Hub client does not appear to return a URL on deletion -- confirm.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.delete_repo(
            name=name,
            organization=organization,
            token=token,
            repo_type=repo_type,
        )

    # The `organization` parameter is deprecated in huggingface_hub>=0.5.0 in
    # favour of a single `repo_id`. Only prefix the namespace when one was
    # given: formatting `None` into the id would target "None/<name>" instead
    # of falling back to the namespace associated with the token.
    repo_id = f"{organization}/{name}" if organization else name
    return hf_api.delete_repo(
        repo_id=repo_id,
        token=token,
        repo_type=repo_type,
    )
예제 #2
0
class TestPushToHub(AllenNlpTestCase):
    """
    Tests for `push_to_hf` that run against the Hub endpoint `ENDPOINT_STAGING`.

    Each test pushes the `simple_tagger` fixture model under `REPO_NAME` and,
    on success, clones the repository back to check that `model.th` was uploaded.
    """

    def setup_method(self):
        """Log in to the Hub and pick the working directories under `TEST_DIR`."""
        super().setup_method()
        self.api = HfApi(ENDPOINT_STAGING)
        self.token = self.api.login(username=USER, password=PASS)
        self.local_repo_path = self.TEST_DIR / "hub"
        self.clone_path = self.TEST_DIR / "hub_clone"

    def teardown_method(self):
        """Delete the test repository from both the user and the organization namespace."""
        super().teardown_method()
        # A given test creates the repository in at most one of the two
        # namespaces, so each deletion is best-effort: an HTTP error (e.g. the
        # repository does not exist there) is expected and deliberately ignored.
        for namespace_kwargs in ({}, {"organization": ORG_NAME}):
            try:
                self.api.delete_repo(token=self.token, name=REPO_NAME, **namespace_kwargs)
            except requests.exceptions.HTTPError:
                pass

    def _assert_model_pushed(self, url, namespace):
        """
        Check that `url` resolves and that the pushed repository contains `model.th`.

        Args:
            url: commit URL returned by `push_to_hf`.
            namespace: user or organization name the repository was pushed under.
        """
        # Check that the returned commit url
        # actually exists.
        response = requests.head(url)
        response.raise_for_status()

        # Clone the repository back and look for the uploaded weights file.
        Repository(
            self.clone_path,
            clone_from=f"{ENDPOINT_STAGING}/{namespace}/{REPO_NAME}",
            use_auth_token=self.token,
        )
        assert "model.th" in os.listdir(self.clone_path)
        shutil.rmtree(self.clone_path)

    @with_staging_testing
    def test_push_to_hub_archive_path(self):
        """Push from a `model.tar.gz` archive to the user's namespace."""
        archive_path = self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
        url = push_to_hf(
            repo_name=REPO_NAME,
            archive_path=archive_path,
            local_repo_path=self.local_repo_path,
            use_auth_token=self.token,
        )
        self._assert_model_pushed(url, USER)

    @with_staging_testing
    def test_push_to_hub_serialization_dir(self):
        """Push from a serialization directory to the user's namespace."""
        serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization"
        url = push_to_hf(
            repo_name=REPO_NAME,
            serialization_dir=serialization_dir,
            local_repo_path=self.local_repo_path,
            use_auth_token=self.token,
        )
        self._assert_model_pushed(url, USER)

    @with_staging_testing
    def test_push_to_hub_to_org(self):
        """Push from a serialization directory to the `ORG_NAME` namespace."""
        serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization"
        url = push_to_hf(
            repo_name=REPO_NAME,
            serialization_dir=serialization_dir,
            organization=ORG_NAME,
            local_repo_path=self.local_repo_path,
            use_auth_token=self.token,
        )
        self._assert_model_pushed(url, ORG_NAME)

    @with_staging_testing
    def test_push_to_hub_fails_with_invalid_token(self):
        """Pushing with an invalid token must raise `requests.exceptions.HTTPError`."""
        serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization"
        with pytest.raises(requests.exceptions.HTTPError):
            push_to_hf(
                repo_name=REPO_NAME,
                serialization_dir=serialization_dir,
                local_repo_path=self.local_repo_path,
                use_auth_token="invalid token",
            )