示例#1
0
def snapshot_download(repo_id: str,
                      revision: Optional[str] = None,
                      cache_dir: Union[str, Path, None] = None,
                      library_name: Optional[str] = None,
                      library_version: Optional[str] = None,
                      user_agent: Union[Dict, str, None] = None,
                      ignore_files: Optional[List[str]] = None) -> str:
    """
    Download every file of a Hub repository snapshot into a local folder.

    Adapted from ``huggingface_hub.snapshot_download`` with one extra
    parameter, ``ignore_files``: a list of ``fnmatch`` patterns. A repo file
    whose name matches at least one of them is not downloaded.

    Returns:
        Path of the folder holding the snapshot. It lives under ``cache_dir``
        (default ``HUGGINGFACE_HUB_CACHE``) and is named after the repo id and
        the resolved commit sha.
    """
    target_cache = HUGGINGFACE_HUB_CACHE if cache_dir is None else cache_dir
    if isinstance(target_cache, Path):
        target_cache = str(target_cache)

    info = HfApi().model_info(repo_id=repo_id, revision=revision)
    folder_name = repo_id.replace("/", REPO_ID_SEPARATOR) + "." + info.sha
    storage_folder = os.path.join(target_cache, folder_name)

    for sibling in info.siblings:
        name = sibling.rfilename
        if ignore_files is not None and any(
                fnmatch.fnmatch(name, pat) for pat in ignore_files):
            continue

        local_rel = os.path.join(*name.split("/"))
        # The file may sit in a subdirectory of the repo; create it first.
        os.makedirs(os.path.dirname(os.path.join(storage_folder, local_rel)),
                    exist_ok=True)

        downloaded = cached_download(
            hf_hub_url(repo_id, filename=name, revision=info.sha),
            cache_dir=storage_folder,
            force_filename=local_rel,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
        )

        # Drop any lock file left next to the downloaded file.
        lock_path = downloaded + ".lock"
        if os.path.exists(lock_path):
            os.remove(lock_path)

    return storage_folder
示例#2
0
def create_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    private: Optional[bool] = None,
    repo_type: Optional[str] = None,
    exist_ok: Optional[bool] = False,
    space_sdk: Optional[str] = None,
) -> str:
    """
    The huggingface_hub.HfApi.create_repo parameters changed in 0.5.0 and some of them were deprecated.
    This function checks the huggingface_hub version to call the right parameters.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization name.
            By default it uses the namespace associated to the token used.
        private (`bool`, *optional*):
            Whether the model repo should be private.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if uploading to a dataset or
            space, `None` or `"model"` if uploading to a model. Default is
            `None`.
        exist_ok (`bool`, *optional*, defaults to `False`):
            If `True`, do not raise an error if repo already exists.
        space_sdk (`str`, *optional*):
            Choice of SDK to use if repo_type is "space". Can be
            "streamlit", "gradio", or "static".

    Returns:
        `str`: URL to the newly created repo.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.create_repo(
            name=name,
            organization=organization,
            token=token,
            private=private,
            repo_type=repo_type,
            exist_ok=exist_ok,
            space_sdk=space_sdk,
        )
    # The `organization` parameter is deprecated in huggingface_hub>=0.5.0:
    # the namespace goes into `repo_id` instead. Without an organization, pass
    # the bare name so the token's namespace is used (as documented above).
    # Otherwise a repo literally named "None/<name>" would be requested.
    repo_id = name if organization is None else f"{organization}/{name}"
    return hf_api.create_repo(
        repo_id=repo_id,
        token=token,
        private=private,
        repo_type=repo_type,
        exist_ok=exist_ok,
        space_sdk=space_sdk,
    )
 def setUpClass(cls):
     """
     Build one staging Hub client and token that every test below reuses.
     """
     cls._token = TOKEN
     cls._api = HfApi(endpoint=ENDPOINT_STAGING)
     cls._api.set_access_token(cls._token)
    def test_push_to_hub_model_kwargs(self):
        """Push a fitted model with extra Keras save kwargs, then reload it."""
        REPO_NAME = repo_name("PUSH_TO_HUB")
        model = self.model_init()
        model = self.model_fit(model)
        push_to_hub_keras(
            model,
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="******",
            git_email="*****@*****.**",
            config={
                "num": 7,
                "act": "gelu_fast"
            },
            include_optimizer=True,
            save_traces=False,
        )

        # The repo exists from here on: delete it even if a check below fails.
        try:
            model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
                f"{USER}/{REPO_NAME}", )
            self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}")

            # Loading the model saved with `save_traces=False` must not raise:
            # any exception here fails the test. The former bare
            # `self.assertRaises(ValueError, msg=...)` call had no callable and
            # no `with` block, so it asserted nothing and has been removed.
            from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        finally:
            self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token)
示例#5
0
def get_type(model_id):
    """
    Return the SpeechBrain ``ModelType`` declared in a Hub model's config.

    The type is read from ``config["speechbrain"]["interface"]`` and upper-cased.

    Raises:
        ValueError: if the repository has no config, or the config has no
            "speechbrain" section.
    """
    config = HfApi().model_info(repo_id=model_id).config
    if not config:
        raise ValueError("no config.json in repository")
    if "speechbrain" not in config:
        raise ValueError("speechbrain not in config.json")
    interface_name = config["speechbrain"]["interface"]
    return ModelType(interface_name.upper())
示例#6
0
    def upload(self, hf_username, model_name, token):
        if token is None:
            token = getpass("Enter your HuggingFace access token")
        if Path("./model").exists():
            shutil.rmtree("./model")
        HfApi.set_access_token(token)
        model_url = HfApi().create_repo(token=token,
                                        name=model_name,
                                        exist_ok=True)
        model_repo = Repository(
            "./model",
            clone_from=model_url,
            use_auth_token=token,
            git_email=f"{hf_username}@users.noreply.huggingface.co",
            git_user=hf_username,
        )

        readme_txt = f"""
        ---
language: "en"
thumbnail: "Keywords to Sentences"
tags:
- keytotext
- k2t
- Keywords to Sentences

model-index:
- name: {model_name}
---

Idea is to build a model which will take keywords as inputs and generate sentences as outputs.

Potential use case can include: 
- Marketing 
- Search Engine Optimization
- Topic generation etc.
- Fine tuning of topic modeling models 
        """.strip()

        (Path(model_repo.local_dir) / "README.md").write_text(readme_txt)
        self.save_model()
        commit_url = model_repo.push_to_hub()

        print("Check out your model at:")
        print(f"https://huggingface.co/{hf_username}/{model_name}")
示例#7
0
def login_to_hub() -> None:
    """
    Make sure a valid Hugging Face Hub token is configured and git-lfs exists.

    Reuses the token stored by ``HfFolder`` when it is valid. Otherwise it runs
    the interactive ``huggingface-cli login`` and then registers the newly
    stored token via ``set_access_token``.

    Raises:
        OSError: if git-lfs is not available.
    """
    api = HfApi()
    access_token = HfFolder.get_token()
    if access_token is not None and api._is_valid_token(access_token):
        logging.info("Huggingface Hub token found and valid")
        api.set_access_token(access_token)
    else:
        subprocess.call(["huggingface-cli", "login"])
        api.set_access_token(HfFolder.get_token())

    # Check that git-lfs is installed. A missing `git` binary raises
    # FileNotFoundError. When git exists but the lfs extension does not,
    # `git lfs version` just exits with a non-zero status, so check that too.
    lfs_missing_msg = (
        "Looks like you do not have git-lfs installed, please install. "
        "You can install from https://git-lfs.github.com/. "
        "Then run `git lfs install` (you only have to do this once)."
    )
    try:
        return_code = subprocess.call(["git", "lfs", "version"])
    except FileNotFoundError as err:
        raise OSError(lfs_missing_msg) from err
    if return_code != 0:
        raise OSError(lfs_missing_msg)
示例#8
0
 def __post_init__(self):
     # When no license was given, try to inherit the one tagged on the
     # checkpoint this model was fine-tuned from (needs online mode).
     can_infer = (self.license is None and not is_offline_mode()
                  and self.finetuned_from is not None)
     if not can_infer:
         return
     prefix = "license:"
     try:
         tags = HfApi().model_info(self.finetuned_from).tags
     except requests.exceptions.HTTPError:
         # Best effort: leave the license unset if the lookup fails.
         return
     # Every matching tag overwrites the previous one, so the last wins.
     for tag in tags:
         if tag.startswith(prefix):
             self.license = tag[len(prefix):]
示例#9
0
def list_adapters(source: str = None,
                  model_name: str = None) -> List[AdapterInfo]:
    """
    Retrieves a list of all publicly available adapters on AdapterHub.ml or on huggingface.co.

    Args:
        source (str, optional): Identifier of the source(s) from where to get adapters. Can be either:

            - "ah": search on AdapterHub.ml.
            - "hf": search on HuggingFace model hub (huggingface.co).
            - None (default): search on all sources

        model_name (str, optional): If specified, only returns adapters trained for the model with this identifier.

    Returns:
        List[AdapterInfo]: The matching adapters. Any other ``source`` value
        yields an empty list.

    Raises:
        EnvironmentError: if the AdapterHub.ml adapter list cannot be downloaded.
    """
    use_ah = source in ("ah", None)
    use_hf = source in ("hf", None)

    adapters = []
    if use_ah:
        try:
            all_ah_adapters_file = download_cached(ADAPTER_HUB_ALL_FILE)
        except requests.exceptions.HTTPError:
            raise EnvironmentError(
                "Unable to load list of adapters from AdapterHub.ml. The service might be temporarily unavailable."
            )
        with open(all_ah_adapters_file, "r") as f:
            ah_entries = json.load(f)
        adapters.extend(AdapterInfo(**entry) for entry in ah_entries)

    if use_hf:
        # Older huggingface-hub releases have no `fetch_config` flag.
        list_models_params = inspect.signature(HfApi.list_models).parameters
        if "fetch_config" in list_models_params:
            kwargs = {"full": True, "fetch_config": True}
        else:
            logger.warning(
                "Using old version of huggingface-hub package for fetching. Please upgrade to latest version for accurate results."
            )
            kwargs = {"full": True}
        for model_info in HfApi().list_models(filter="adapter-transformers",
                                              **kwargs):
            config = model_info.config
            base_model = (config.get("adapter_transformers",
                                     {}).get("model_name")
                          if config else None)
            adapters.append(
                AdapterInfo(
                    source="hf",
                    adapter_id=model_info.modelId,
                    model_name=base_model,
                    username=model_info.modelId.split("/")[0],
                    sha1_checksum=model_info.sha,
                ))

    if model_name is None:
        return adapters
    return [item for item in adapters if item.model_name == model_name]
示例#10
0
def push_to_hf_hub(
    model,
    local_dir,
    repo_namespace_or_url=None,
    commit_message='Add model',
    use_auth_token=True,
    git_email=None,
    git_user=None,
    revision=None,
    model_config=None,
):
    """
    Save ``model`` into a local clone of a Hub repo, commit and push it.

    Args:
        model: Model whose weights and config are written by ``save_for_hf``.
        local_dir: Local directory used for the repository clone.
        repo_namespace_or_url: "owner/name" or a repo URL. Its last two path
            segments give the owner and name. When omitted, the owner is the
            token's user and the name is the basename of ``local_dir``.
        commit_message: Commit message for the push.
        use_auth_token: ``True`` to use the locally saved token, or a token string.
        git_email, git_user: Git identity for the commit.
        revision: Revision to check out in the clone.
        model_config: Extra config forwarded to ``save_for_hf``.

    Returns:
        The git remote URL of the repository.

    Raises:
        ValueError: if no repo is given and no token is available.
    """
    if not repo_namespace_or_url:
        token = (use_auth_token if isinstance(use_auth_token, str) else
                 HfFolder.get_token())
        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
                "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                "token as the `use_auth_token` argument.")
        repo_owner = HfApi().whoami(token)['name']
        repo_name = Path(local_dir).name
    else:
        segments = repo_namespace_or_url.rstrip('/').split('/')
        repo_owner, repo_name = segments[-2:]

    repo = Repository(
        local_dir,
        clone_from=f'https://huggingface.co/{repo_owner}/{repo_name}',
        use_auth_token=use_auth_token,
        git_user=git_user,
        git_email=git_email,
        revision=revision,
    )

    # Default model card carrying the tags needed to enable inference.
    readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
    with repo.commit(commit_message):
        save_for_hf(model, repo.local_dir, model_config=model_config)
        # Never overwrite an existing model card.
        readme_path = Path(repo.local_dir) / 'README.md'
        if not readme_path.exists():
            readme_path.write_text(readme_text)

    return repo.git_remote_url()
示例#11
0
def delete_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> str:
    """
    The huggingface_hub.HfApi.delete_repo parameters changed in 0.5.0 and some of them were deprecated.
    This function checks the huggingface_hub version to call the right parameters.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization name.
            By default it uses the namespace associated to the token used.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if deleting a dataset or
            space, `None` or `"model"` if deleting a model. Default is
            `None`.

    Returns:
        Whatever `hf_api.delete_repo` returns, passed through unchanged.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.delete_repo(
            name=name,
            organization=organization,
            token=token,
            repo_type=repo_type,
        )
    # The `organization` parameter is deprecated in huggingface_hub>=0.5.0:
    # the namespace goes into `repo_id` instead. Without an organization, pass
    # the bare name so the token's namespace is used (as documented above).
    # Otherwise the repo "None/<name>" would be targeted.
    repo_id = name if organization is None else f"{organization}/{name}"
    return hf_api.delete_repo(
        repo_id=repo_id,
        token=token,
        repo_type=repo_type,
    )
示例#12
0
    def __init__(self,
                 dataset=None,
                 device='cpu',
                 extract='tokens',
                 randomize=True,
                 remove_stopwords=True,
                 lemmatizer=None):
        """
        Store the extraction options and set up the optional rationalizer.

        When ``dataset`` is given, the Hub is searched for a "pytorch" +
        "sibyl" model trained on that dataset. If one is found, its sequence
        classifier, tokenizer and explainer are loaded to rationalize
        keyphrase selections. Without a dataset, ``extract`` is forced to
        "concepts". ``lemmatizer`` defaults to WordNet's ``lemmatize``.
        English stopwords are loaded only when ``remove_stopwords`` is true.
        """
        self.dataset = dataset
        self.device = device
        self.extract = extract
        self.randomize = randomize
        self.remove_stopwords = remove_stopwords
        self.model = None
        self.tokenizer = None
        self.interpreter = None
        if lemmatizer is None:
            lemmatizer = WordNetLemmatizer().lemmatize
        self.lemmatizer = lemmatizer

        if dataset is None:
            self.extract = "concepts"
        else:
            # Find a Hub model that can rationalize the extracted keyphrases.
            candidates = HfApi().list_models(
                filter=("pytorch", "dataset:" + dataset, "sibyl"))
            if candidates:
                model_id = candidates[0].modelId
                print('Using ' + model_id +
                      ' to rationalize keyphrase selections.')
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    model_id).to(self.device)
                self.tokenizer = AutoTokenizer.from_pretrained(model_id)
                self.interpreter = SequenceClassificationExplainer(
                    self.model, self.tokenizer)

        self.stops = (stopwords.words('english')
                      if self.remove_stopwords else [])
 def test_push_to_hub_model_card_plot_false(self):
     """With ``plot_model=False`` the pushed repo must not contain model.png."""
     REPO_NAME = repo_name("PUSH_TO_HUB")
     model = self.model_init()
     model = self.model_fit(model)
     push_to_hub_keras(
         model,
         repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
         api_endpoint=ENDPOINT_STAGING,
         use_auth_token=self._token,
         git_user="******",
         git_email="*****@*****.**",
         plot_model=False,
     )
     # The repo exists from here on: delete it even if the assertion fails.
     try:
         model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
             f"{USER}/{REPO_NAME}", )
         self.assertNotIn(
             "model.png", [f.rfilename for f in model_info.siblings])
     finally:
         self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token)
示例#14
0
def get_list_of_files(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
) -> List[str]:
    """
    List the files available in a local directory or a huggingface.co repo.

    Args:
        path_or_repo (:obj:`str` or :obj:`os.PathLike`):
            Either a path to a local directory, which is walked recursively,
            or the id of a repo on huggingface.co.
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            Branch name, tag name or commit id of the repo to inspect; any
            identifier git accepts. Unused for local directories.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            Bearer token for the Hub request. A string is used as-is.
            :obj:`True` uses the token saved by :obj:`transformers-cli login`
            (in :obj:`~/.huggingface`).

    Returns:
        :obj:`List[str]`: Depends on the input:

        - a directory: the path of every file beneath it, subdirectories
          included;
        - offline mode: an empty list;
        - otherwise: the file names the Hub reports for the repo.
    """
    path_or_repo = str(path_or_repo)

    if os.path.isdir(path_or_repo):
        return [
            os.path.join(root, file_name)
            for root, _subdirs, file_names in os.walk(path_or_repo)
            for file_name in file_names
        ]

    # The remote listing cannot be fetched without network access.
    if is_offline_mode():
        return []

    if isinstance(use_auth_token, str):
        token = use_auth_token
    else:
        token = HfFolder.get_token() if use_auth_token is True else None

    hub = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT)
    info = hub.model_info(path_or_repo, revision=revision, token=token)
    return [sibling.rfilename for sibling in info.siblings]
示例#15
0
def resolve(model_id: str) -> [str, str]:
    """
    Look up the task and framework of a Hub model.

    Returns:
        A ``(task, framework)`` pair: the model's ``pipeline_tag``, and its
        ``library_name`` with dashes replaced by underscores.

    Raises:
        ValueError: if the model info cannot be fetched, or if the model has
            no ``pipeline_tag`` or no ``library_name``.
    """
    try:
        info = HfApi().model_info(model_id)
    except Exception as e:
        raise ValueError(
            f"The hub has no information on {model_id}, does it exist: {e}"
        ) from e
    # Missing metadata may appear as an absent attribute or as None; report
    # both. Previously a None `pipeline_tag` was returned silently, and a None
    # `library_name` crashed in `.replace` with an unhelpful AttributeError.
    task = getattr(info, "pipeline_tag", None)
    if task is None:
        raise ValueError(
            f"The hub has no `pipeline_tag` on {model_id}, you can set it in the `README.md` yaml header"
        )
    framework = getattr(info, "library_name", None)
    if framework is None:
        raise ValueError(
            f"The hub has no `library_name` on {model_id}, you can set it in the `README.md` yaml header"
        )
    return task, framework.replace("-", "_")
示例#16
0
    def test_push_to_hub(self):
        """Push a freshly built model with a config and check the repo exists."""
        REPO_NAME = repo_name("PUSH_TO_HUB")
        model = self.model_init()
        model.build((None, 2))
        push_to_hub_keras(
            model,
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="******",
            git_email="*****@*****.**",
            config={
                "num": 7,
                "act": "gelu_fast"
            },
        )

        # The repo exists from here on: delete it even if the assertion fails.
        try:
            model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
                f"{USER}/{REPO_NAME}", )
            self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}")
        finally:
            # `repo_id=` matches the other push tests; `name=` is the
            # pre-0.5.0 huggingface_hub spelling.
            self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token)
示例#17
0
def get_adapter_info(adapter_id: str,
                     source: str = "ah") -> Optional[AdapterInfo]:
    """
    Retrieves information about a specific adapter.

    Args:
        adapter_id (str): The identifier of the adapter to retrieve. On
            AdapterHub.ml a leading "@" is ignored.
        source (str, optional): Identifier of the source(s) from where to get adapters. Can be either:

            - "ah": search on AdapterHub.ml.
            - "hf": search on HuggingFace model hub (huggingface.co).

    Returns:
        AdapterInfo: The adapter information or None if the adapter was not found.

    Raises:
        ValueError: if ``source`` is neither "ah" nor "hf".
    """
    if source not in ("ah", "hf"):
        raise ValueError("Please specify either 'ah' or 'hf' as source.")

    if source == "hf":
        try:
            model_info = HfApi().model_info(adapter_id)
            config = model_info.config
            base_model = (config.get("adapter_transformers",
                                     {}).get("model_name")
                          if config else None)
            return AdapterInfo(
                source="hf",
                adapter_id=model_info.modelId,
                model_name=base_model,
                username=model_info.modelId.split("/")[0],
                sha1_checksum=model_info.sha,
            )
        except requests.exceptions.HTTPError:
            return None

    # source == "ah"
    if adapter_id.startswith("@"):
        adapter_id = adapter_id[1:]
    try:
        data = http_get_json(f"/adapters/{adapter_id}.json")
        return AdapterInfo(**data["info"])
    except EnvironmentError:
        return None
    def test_override_tensorboard(self):
        """A second push with a new ``log_dir`` must replace the pushed logs."""
        REPO_NAME = repo_name("PUSH_TO_HUB")
        with tempfile.TemporaryDirectory() as tmpdirname:
            os.makedirs(f"{tmpdirname}/tb_log_dir")
            with open(f"{tmpdirname}/tb_log_dir/tensorboard.txt", "w") as fp:
                fp.write("Keras FTW")
            model = self.model_init()
            model.build((None, 2))
            push_to_hub_keras(
                model,
                repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
                log_dir=f"{tmpdirname}/tb_log_dir",
                api_endpoint=ENDPOINT_STAGING,
                use_auth_token=self._token,
                git_user="******",
                git_email="*****@*****.**",
            )
            # The repo exists from here on: delete it even if the second push
            # or an assertion fails.
            try:
                os.makedirs(f"{tmpdirname}/tb_log_dir2")
                with open(f"{tmpdirname}/tb_log_dir2/override.txt",
                          "w") as fp:
                    fp.write("Keras FTW")
                push_to_hub_keras(
                    model,
                    repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}",
                    log_dir=f"{tmpdirname}/tb_log_dir2",
                    api_endpoint=ENDPOINT_STAGING,
                    use_auth_token=self._token,
                    git_user="******",
                    git_email="*****@*****.**",
                )

                model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(
                    f"{USER}/{REPO_NAME}", )
                repo_files = [f.rfilename for f in model_info.siblings]
                self.assertIn("logs/override.txt", repo_files)
                self.assertNotIn("logs/tensorboard.txt", repo_files)
            finally:
                self._api.delete_repo(repo_id=f"{REPO_NAME}",
                                      token=self._token)
示例#19
0
    def test_full_deserialization_hub(self):
        """
        Check that tokenizer.json files from the Hub can be deserialized.

        This used to fail because of a BufReader that could not handle files
        exceeding its buffer capacity.
        """
        not_loadable = []
        invalid_pre_tokenizer = []

        # To scan the whole Hub instead of the pinned sample below, build
        # `all_models` from `HfApi().list_models(filter="transformers")`.
        # Keep each (model.modelId, sibling.rfilename) pair whose rfilename
        # is "tokenizer.json".
        all_models = [("HueyNemud/das22-10-camembert_pretrained",
                       "tokenizer.json")]
        for (model_id, filename) in tqdm.tqdm(all_models):
            tokenizer_file = cached_download(
                hf_hub_url(model_id, filename=filename))

            is_ok = check(tokenizer_file)
            if not is_ok:
                print(f"{model_id} is affected by no type")
                invalid_pre_tokenizer.append(model_id)
            try:
                Tokenizer.from_file(tokenizer_file)
            except Exception as e:
                print(f"{model_id} is not loadable: {e}")
                not_loadable.append(model_id)
            except BaseException:
                # Failures not derived from Exception (reported as Rust
                # errors) land here.
                print(f"{model_id} is not loadable: Rust error")
                not_loadable.append(model_id)

        # Assert once, after every model has been checked, so a failure lists
        # all affected models instead of stopping at the first one.
        self.assertEqual(invalid_pre_tokenizer, [])
        self.assertEqual(not_loadable, [])
示例#20
0
    def save_to_hub(self,
                    repo_name: str,
                    organization: Optional[str] = None,
                    private: Optional[bool] = None,
                    commit_message: str = "Add new SentenceTransformer model.",
                    local_model_path: Optional[str] = None,
                    exist_ok: bool = False,
                    replace_model_card: bool = False):
        """
        Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

        :param repo_name: Repository name for your model in the Hub. May also be "organization/name", in which case the prefix is used as the organization.
        :param organization:  Organization in which you want to push your model or tokenizer (you must be a member of this organization).
        :param private: Set to true, for hosting a private model
        :param commit_message: Message to commit while pushing.
        :param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
        :param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
        :param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card
        :return: The url of the commit of your model in the given repository.
        :raises ValueError: if no Hub token is stored locally, or if the namespace in `repo_name` conflicts with `organization`.
        """
        token = HfFolder.get_token()
        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `transformers-cli login`."
            )

        if '/' in repo_name:
            splits = repo_name.split('/', maxsplit=1)
            if organization is None or organization == splits[0]:
                organization = splits[0]
                repo_name = splits[1]
            else:
                raise ValueError(
                    "You passed an invalid repository name: {}.".format(
                        repo_name))

        endpoint = "https://huggingface.co"
        repo_url = HfApi(endpoint=endpoint).create_repo(
            token,
            repo_name,
            organization=organization,
            private=private,
            repo_type=None,
            exist_ok=exist_ok,
        )
        # Drop the "<endpoint>/" prefix to keep "<namespace>/<name>".
        full_model_name = repo_url[len(endpoint) + 1:].strip("/")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # First create the repo (and clone its content if it's nonempty).
            logging.info("Create repository and clone it if it exists")
            repo = Repository(tmp_dir, clone_from=repo_url)

            # If user provides local files, copy them.
            if local_model_path:
                copy_tree(local_model_path, tmp_dir)
            else:  # Else, save model directly into local repo.
                # Only write a model card when asked to, or when the clone has none.
                create_model_card = replace_model_card or not os.path.exists(
                    os.path.join(tmp_dir, 'README.md'))
                self.save(tmp_dir,
                          model_name=full_model_name,
                          create_model_card=create_model_card)

            # Find working-tree files larger than 5 MB and track them with
            # git-lfs. Git's own `.git` directory is skipped, because its
            # internal files are not part of the repository content.
            large_files = []
            for root, dirs, files in os.walk(tmp_dir):
                dirs[:] = [d for d in dirs if d != ".git"]
                for filename in files:
                    file_path = os.path.join(root, filename)
                    rel_path = os.path.relpath(file_path, tmp_dir)

                    if os.path.getsize(file_path) > (5 * 1024 * 1024):
                        large_files.append(rel_path)

            if large_files:
                logging.info("Track files with git lfs: {}".format(
                    ", ".join(large_files)))
                repo.lfs_track(large_files)

            logging.info("Push model to the hub. This might take a while")
            push_return = repo.push_to_hub(commit_message=commit_message)

            def on_rm_error(func, path, exc_info):
                # path contains the path of the file that couldn't be removed
                # let's just assume that it's read-only and unlink it.
                try:
                    os.chmod(path, stat.S_IWRITE)
                    os.unlink(path)
                except OSError:
                    # Best effort: TemporaryDirectory cleanup gets another try.
                    pass

            # Empty the temp folder (including .git) before TemporaryDirectory
            # removes it. On Windows the .git folder might be read-only and
            # cannot be deleted, so on_rm_error sets write permission and retries.
            try:
                for f in os.listdir(tmp_dir):
                    shutil.rmtree(os.path.join(tmp_dir, f),
                                  onerror=on_rm_error)
            except Exception as e:
                logging.warning("Error when deleting temp folder: {}".format(
                    str(e)))

        return push_return
示例#21
0
class TestPushToHub(TestCase):
    """Round-trip tests for ``push_to_hub`` / ``load_dataset`` against the staging Hub.

    Each test pushes to a fresh ``USER/test-<timestamp>`` dataset repo and deletes
    that repo in a ``finally`` clause so that failing assertions do not leave
    repos behind.
    """

    _api = HfApi(endpoint=ENDPOINT_STAGING)

    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._hf_folder_patch = patch(
            "huggingface_hub.hf_api.HfFolder.path_token",
            TOKEN_PATH_STAGING,
        )
        cls._hf_folder_patch.start()

        # unittest does not run tearDownClass when setUpClass raises, so undo the
        # patch here if logging in fails; otherwise it leaks into later classes.
        try:
            cls._token = cls._api.login(username=USER, password=PASS)
            HfFolder.save_token(cls._token)
        except Exception:
            cls._hf_folder_patch.stop()
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        HfFolder.delete_token()
        cls._hf_folder_patch.stop()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _new_ds_name():
        """Return a timestamp-based dataset repo id under ``USER``."""
        return f"{USER}/test-{int(time.time() * 10e3)}"

    @staticmethod
    def _token_kwargs(token):
        """Return ``{"token": token}``, or ``{}`` when no token should be passed."""
        return {} if token is None else {"token": token}

    def _delete_dataset_repo(self, ds_name, token=None):
        """Delete the dataset repo ``ds_name`` (given as ``"namespace/name"``).

        ``token`` is forwarded only when given, so tests that rely on the token
        saved in ``setUpClass`` keep calling ``delete_repo`` without one.
        """
        parts = ds_name.split("/")
        self._api.delete_repo(parts[1],
                              organization=parts[0],
                              repo_type="dataset",
                              **self._token_kwargs(token))

    def _list_dataset_files(self, ds_name, token=None):
        """Return the sorted file list of the dataset repo ``ds_name``."""
        return sorted(
            self._api.list_repo_files(ds_name,
                                      repo_type="dataset",
                                      **self._token_kwargs(token)))

    def _upload_bogus_datafile(self, ds_name):
        """Upload a ``datafile.txt`` to ``ds_name``.

        Its name starts with "data" so the tests can check that re-pushing a
        dataset does not delete unrelated files sharing that prefix.
        """
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "datafile.txt"
            with open(path, "w") as f:
                f.write("Bogus file")

            self._api.upload_file(str(path),
                                  path_in_repo="datafile.txt",
                                  repo_id=ds_name,
                                  repo_type="dataset",
                                  token=self._token)

    def _assert_dataset_dict_equal(self, local_ds, hub_ds, split):
        """Assert both dicts share column names, and ``split`` shares its features."""
        self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds[split].features.keys()),
                             list(hub_ds[split].features.keys()))
        self.assertDictEqual(local_ds[split].features,
                             hub_ds[split].features)

    def _assert_dataset_equal(self, local_ds, hub_ds):
        """Assert two single-split datasets share column names and features."""
        self.assertListEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds.features.keys()),
                             list(hub_ds.features.keys()))
        self.assertDictEqual(local_ds.features, hub_ds.features)

    # -------------------------------------------------------------------- tests

    def test_push_dataset_dict_to_hub_no_token(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"train": ds})

        ds_name = self._new_ds_name()
        try:
            # No explicit token: relies on the token saved in setUpClass.
            local_ds.push_to_hub(ds_name)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "train")

            # Ensure that there is a single data file on the repository that has the correct name
            self.assertListEqual(self._list_dataset_files(ds_name), [
                ".gitattributes", "data/train-00000-of-00001.parquet",
                "dataset_infos.json"
            ])
        finally:
            self._delete_dataset_repo(ds_name)

    def test_push_dataset_dict_to_hub_name_without_namespace(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"train": ds})

        ds_name = self._new_ds_name()
        try:
            # Push using only the repo name; the namespace comes from the token.
            local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "train")

            # Ensure that there is a single data file on the repository that has the correct name
            self.assertListEqual(self._list_dataset_files(ds_name), [
                ".gitattributes", "data/train-00000-of-00001.parquet",
                "dataset_infos.json"
            ])
        finally:
            self._delete_dataset_repo(ds_name)

    def test_push_dataset_dict_to_hub_private(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"train": ds})

        ds_name = self._new_ds_name()
        try:
            local_ds.push_to_hub(ds_name, token=self._token, private=True)
            hub_ds = load_dataset(ds_name,
                                  download_mode="force_redownload",
                                  use_auth_token=self._token)

            self._assert_dataset_dict_equal(local_ds, hub_ds, "train")

            # Ensure that there is a single data file on the repository that has the correct name
            self.assertListEqual(
                self._list_dataset_files(ds_name, token=self._token), [
                    ".gitattributes", "data/train-00000-of-00001.parquet",
                    "dataset_infos.json"
                ])
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_dict_to_hub(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"train": ds})

        ds_name = self._new_ds_name()
        try:
            local_ds.push_to_hub(ds_name, token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "train")

            # Ensure that there is a single data file on the repository that has the correct name
            self.assertListEqual(
                self._list_dataset_files(ds_name, token=self._token), [
                    ".gitattributes", "data/train-00000-of-00001.parquet",
                    "dataset_infos.json"
                ])
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_dict_to_hub_multiple_files(self):
        ds = Dataset.from_dict({
            "x": list(range(1000)),
            "y": list(range(1000))
        })

        local_ds = DatasetDict({"train": ds})

        ds_name = self._new_ds_name()
        try:
            # A small shard size forces the train split into two parquet shards.
            local_ds.push_to_hub(ds_name,
                                 token=self._token,
                                 shard_size=500 << 5)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "train")

            # Ensure that there are two data shards on the repository that have the correct names
            self.assertListEqual(
                self._list_dataset_files(ds_name, token=self._token),
                [
                    ".gitattributes",
                    "data/train-00000-of-00002.parquet",
                    "data/train-00001-of-00002.parquet",
                    "dataset_infos.json",
                ],
            )
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_dict_to_hub_overwrite_files(self):
        ds = Dataset.from_dict({
            "x": list(range(1000)),
            "y": list(range(1000))
        })
        ds2 = Dataset.from_dict({"x": list(range(100)), "y": list(range(100))})

        local_ds = DatasetDict({"train": ds, "random": ds2})

        ds_name = self._new_ds_name()

        # Push to hub two times, but the second time with a larger amount of files.
        # Verify that the new files contain the correct dataset.
        try:
            local_ds.push_to_hub(ds_name, token=self._token)

            # Add a file starting with "data" to ensure it doesn't get deleted.
            self._upload_bogus_datafile(ds_name)

            local_ds.push_to_hub(ds_name,
                                 token=self._token,
                                 shard_size=500 << 5)

            # The train split is now sharded in two, and the bogus file survives.
            self.assertListEqual(
                self._list_dataset_files(ds_name, token=self._token),
                [
                    ".gitattributes",
                    "data/random-00000-of-00001.parquet",
                    "data/train-00000-of-00002.parquet",
                    "data/train-00001-of-00002.parquet",
                    "datafile.txt",
                    "dataset_infos.json",
                ],
            )

            # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset
            self._api.delete_file("datafile.txt",
                                  repo_id=ds_name,
                                  repo_type="dataset",
                                  token=self._token)

            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "train")

        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

        # Push to hub two times, but the second time with fewer files.
        # Verify that the new files contain the correct dataset and that non-necessary files have been deleted.
        try:
            local_ds.push_to_hub(ds_name,
                                 token=self._token,
                                 shard_size=500 << 5)

            # Add a file starting with "data" to ensure it doesn't get deleted.
            self._upload_bogus_datafile(ds_name)

            local_ds.push_to_hub(ds_name, token=self._token)

            # The stale second train shard is gone; the bogus file survives.
            self.assertListEqual(
                self._list_dataset_files(ds_name, token=self._token),
                [
                    ".gitattributes",
                    "data/random-00000-of-00001.parquet",
                    "data/train-00000-of-00001.parquet",
                    "datafile.txt",
                    "dataset_infos.json",
                ],
            )

            # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset
            self._api.delete_file("datafile.txt",
                                  repo_id=ds_name,
                                  repo_type="dataset",
                                  token=self._token)

            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "train")

        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_to_hub(self):
        local_ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        ds_name = self._new_ds_name()
        try:
            local_ds.push_to_hub(ds_name, split="train", token=self._token)
            local_ds_dict = {"train": local_ds}
            hub_ds_dict = load_dataset(ds_name,
                                       download_mode="force_redownload")

            self.assertListEqual(list(local_ds_dict.keys()),
                                 list(hub_ds_dict.keys()))

            for split, local_split in local_ds_dict.items():
                self._assert_dataset_equal(local_split, hub_ds_dict[split])
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_to_hub_custom_features(self):
        features = Features({
            "x": Value("int64"),
            "y": ClassLabel(names=["neg", "pos"])
        })
        ds = Dataset.from_dict({
            "x": [1, 2, 3],
            "y": [0, 0, 1]
        },
                               features=features)

        ds_name = self._new_ds_name()
        try:
            ds.push_to_hub(ds_name, token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_equal(ds, hub_ds["train"])
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_dict_to_hub_custom_features(self):
        features = Features({
            "x": Value("int64"),
            "y": ClassLabel(names=["neg", "pos"])
        })
        ds = Dataset.from_dict({
            "x": [1, 2, 3],
            "y": [0, 0, 1]
        },
                               features=features)

        local_ds = DatasetDict({"test": ds})

        ds_name = self._new_ds_name()
        try:
            local_ds.push_to_hub(ds_name, token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "test")
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_to_hub_custom_splits(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        ds_name = self._new_ds_name()
        try:
            ds.push_to_hub(ds_name, split="random", token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_equal(ds, hub_ds["random"])
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    def test_push_dataset_dict_to_hub_custom_splits(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"random": ds})

        ds_name = self._new_ds_name()
        try:
            local_ds.push_to_hub(ds_name, token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self._assert_dataset_dict_equal(local_ds, hub_ds, "random")
        finally:
            self._delete_dataset_repo(ds_name, token=self._token)

    @unittest.skip(
        "This test cannot pass until iterable datasets have push to hub")
    def test_push_streaming_dataset_dict_to_hub(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
        local_ds = DatasetDict({"train": ds})
        with tempfile.TemporaryDirectory() as tmp:
            local_ds.save_to_disk(tmp)
            local_ds = load_dataset(tmp, streaming=True)

            ds_name = self._new_ds_name()
            try:
                local_ds.push_to_hub(ds_name, token=self._token)
                hub_ds = load_dataset(ds_name,
                                      download_mode="force_redownload")

                self._assert_dataset_dict_equal(local_ds, hub_ds, "train")
            finally:
                self._delete_dataset_repo(ds_name, token=self._token)
示例#22
0
class TestPushToHub(AllenNlpTestCase):
    """Staging-Hub integration tests for ``push_to_hf``."""

    def setup_method(self):
        super().setup_method()
        self.api = HfApi(ENDPOINT_STAGING)
        self.token = self.api.login(username=USER, password=PASS)
        self.local_repo_path = self.TEST_DIR / "hub"
        self.clone_path = self.TEST_DIR / "hub_clone"

    def teardown_method(self):
        super().teardown_method()
        # A test may have pushed to the user namespace or to ORG_NAME; try to
        # delete both and ignore the HTTP error for whichever does not exist.
        for namespace_kwargs in ({}, {"organization": ORG_NAME}):
            try:
                self.api.delete_repo(
                    token=self.token,
                    name=REPO_NAME,
                    **namespace_kwargs,
                )
            except requests.exceptions.HTTPError:
                pass

    def _check_pushed_repo(self, url, namespace):
        """Assert that ``url`` resolves and ``namespace/REPO_NAME`` contains ``model.th``.

        The clone at ``self.clone_path`` is removed even when a check fails.
        """
        # Check that the returned commit url
        # actually exists.
        r = requests.head(url)
        r.raise_for_status()

        try:
            Repository(
                self.clone_path,
                clone_from=f"{ENDPOINT_STAGING}/{namespace}/{REPO_NAME}",
                use_auth_token=self.token,
            )
            assert "model.th" in os.listdir(self.clone_path)
        finally:
            # ignore_errors: the clone may not exist if cloning itself failed.
            shutil.rmtree(self.clone_path, ignore_errors=True)

    @with_staging_testing
    def test_push_to_hub_archive_path(self):
        archive_path = self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
        url = push_to_hf(
            repo_name=REPO_NAME,
            archive_path=archive_path,
            local_repo_path=self.local_repo_path,
            use_auth_token=self.token,
        )
        self._check_pushed_repo(url, USER)

    @with_staging_testing
    def test_push_to_hub_serialization_dir(self):
        serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization"
        url = push_to_hf(
            repo_name=REPO_NAME,
            serialization_dir=serialization_dir,
            local_repo_path=self.local_repo_path,
            use_auth_token=self.token,
        )
        self._check_pushed_repo(url, USER)

    @with_staging_testing
    def test_push_to_hub_to_org(self):
        serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization"
        url = push_to_hf(
            repo_name=REPO_NAME,
            serialization_dir=serialization_dir,
            organization=ORG_NAME,
            local_repo_path=self.local_repo_path,
            use_auth_token=self.token,
        )
        self._check_pushed_repo(url, ORG_NAME)

    @with_staging_testing
    def test_push_to_hub_fails_with_invalid_token(self):
        serialization_dir = self.FIXTURES_ROOT / "simple_tagger" / "serialization"
        with pytest.raises(requests.exceptions.HTTPError):
            push_to_hf(
                repo_name=REPO_NAME,
                serialization_dir=serialization_dir,
                local_repo_path=self.local_repo_path,
                use_auth_token="invalid token",
            )
示例#23
0
def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:
    """Save model and its configuration on HF hub

    >>> from doctr.models import login_to_hub, push_to_hf_hub
    >>> from doctr.models.recognition import crnn_mobilenet_v3_small
    >>> login_to_hub()
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

    Args:
        model: TF or PyTorch model to be saved
        model_name: name of the model which is also the repository name
        task: task name
        **kwargs: keyword arguments for push_to_hf_hub. ``run_config`` (an
            object whose ``vars()`` are dumped into the README and whose
            ``arch`` attribute overrides ``arch``) and/or ``arch`` (architecture
            name) are read; at least one of them is required.

    Raises:
        ValueError: if neither ``run_config`` nor ``arch`` is given, if ``task``
            is not a supported task, or if the architecture is not available
            for ``task``.
    """
    run_config = kwargs.get("run_config", None)
    arch = kwargs.get("arch", None)

    if run_config is None and arch is None:
        raise ValueError("run_config or arch must be specified")
    if task not in [
            "classification", "detection", "recognition", "obj_detection"
    ]:
        raise ValueError(
            "task must be one of classification, detection, recognition, obj_detection"
        )

    # default readme
    readme = textwrap.dedent(f"""
    ---
    language: en
    ---

    <p align="center">
    <img src="https://github.com/mindee/doctr/releases/download/v0.3.1/Logo_doctr.gif" width="60%">
    </p>

    **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**

    ## Task: {task}

    https://github.com/mindee/doctr

    ### Example usage:

    ```python
    >>> from doctr.io import DocumentFile
    >>> from doctr.models import ocr_predictor, from_hub

    >>> img = DocumentFile.from_images(['<image_path>'])
    >>> # Load your model from the hub
    >>> model = from_hub('mindee/my-model')

    >>> # Pass it to the predictor
    >>> # If your model is a recognition model:
    >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
    >>>                           reco_arch=model,
    >>>                           pretrained=True)

    >>> # If your model is a detection model:
    >>> predictor = ocr_predictor(det_arch=model,
    >>>                           reco_arch='crnn_mobilenet_v3_small',
    >>>                           pretrained=True)

    >>> # Get your predictions
    >>> res = predictor(img)
    ```
    """)

    # add run configuration to readme if available
    if run_config is not None:
        # run_config.arch takes precedence over an explicitly passed ``arch``.
        arch = run_config.arch
        readme += "### Run Configuration\n\n" + json.dumps(
            vars(run_config), indent=2, ensure_ascii=False)

    if arch not in AVAILABLE_ARCHS[task]:  # type: ignore
        # Explicit "\n" instead of a backslash continuation inside the literal,
        # which used to embed a run of indentation spaces in the message.
        raise ValueError(f"Architecture: {arch} for task: {task} not found.\n"
                         f"Available architectures: {AVAILABLE_ARCHS}")

    commit_message = f"Add {model_name} model"

    local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache",
                                   "huggingface", "hub", model_name)
    # exist_ok=False: do not silently reuse an existing repository.
    repo_url = HfApi().create_repo(model_name,
                                   token=HfFolder.get_token(),
                                   exist_ok=False)
    repo = Repository(local_dir=local_cache_dir,
                      clone_from=repo_url,
                      use_auth_token=True)

    with repo.commit(commit_message):

        _save_model_and_config_for_hf_hub(model,
                                          repo.local_dir,
                                          arch=arch,
                                          task=task)
        readme_path = Path(repo.local_dir) / "README.md"
        readme_path.write_text(readme)

    repo.git_push()
示例#24
0
 def setup_method(self):
     """Build a staging ``HfApi`` client, log in, and define local repo paths."""
     super().setup_method()
     client = HfApi(ENDPOINT_STAGING)
     self.api = client
     self.token = client.login(username=USER, password=PASS)
     base_dir = self.TEST_DIR
     self.local_repo_path = base_dir / "hub"
     self.clone_path = base_dir / "hub_clone"
示例#25
0
def push_to_hf(
    repo_name: str,
    serialization_dir: Optional[Union[str, PathLike]] = None,
    archive_path: Optional[Union[str, PathLike]] = None,
    organization: Optional[str] = None,
    commit_message: str = "Update repository",
    local_repo_path: Union[str, PathLike] = "hub",
    use_auth_token: Union[bool, str] = True,
) -> str:
    """Pushes model and related files to the Hugging Face Hub ([hf.co](https://hf.co/))

    # Parameters

    repo_name: `str`
        Name of the repository in the Hugging Face Hub.

    serialization_dir : `Union[str, PathLike]`, optional (default = `None`)
        Full path to a directory with the serialized model.

    archive_path : `Union[str, PathLike]`, optional (default = `None`)
        Full path to the zipped model (e.g. model/model.tar.gz). Use `serialization_dir` if possible.

    organization : `Optional[str]`, optional (default = `None`)
        Name of organization to which the model should be uploaded.

    commit_message: `str` (default=`Update repository`)
        Commit message to use for the push.

    local_repo_path : `Union[str, PathLike]`, optional (default=`hub`)
        Local directory where the repository will be saved.

    use_auth_token (``str`` or ``bool``, `optional`, defaults ``True``):
        huggingface_token can be extract from ``HfApi().login(username, password)`` and is used to authenticate
        against the Hugging Face Hub (useful from Google Colab for instance). It's automatically retrieved
        if you've done `huggingface-cli login` before. ``False`` passes no token to ``create_repo``.

    # Returns

    `str`
        The URL of the repository returned by ``HfApi.create_repo``.

    # Raises

    `ValueError`
        If neither or both of `serialization_dir` / `archive_path` are given, if
        `serialization_dir` is not an existing directory, or if `archive_path` is
        not an existing zip or tar file.
    """

    if serialization_dir is not None:
        working_dir = Path(serialization_dir)
        if archive_path is not None:
            raise ValueError(
                "serialization_dir and archive_path are mutually exclusive, please just use one."
            )
        # is_dir() is False for non-existent paths too.
        if not working_dir.is_dir():
            raise ValueError(
                f"Can't find path: {serialization_dir}, please point "
                "to a directory with the serialized model.")
    elif archive_path is not None:
        working_dir = Path(archive_path)
        # Check is_file() first: tarfile.is_tarfile raises OSError (rather than
        # returning False) when handed a directory.
        is_archive = working_dir.is_file() and (
            zipfile.is_zipfile(working_dir) or tarfile.is_tarfile(working_dir))
        if not is_archive:
            raise ValueError(
                f"Can't find path: {archive_path}, please point to a .tar.gz archive "
                "or to a directory with the serialized model.")
        logging.info(
            "Using the archive_path is discouraged. Using the serialization_dir "
            "will also upload metrics and TensorBoard traces to the Hugging Face Hub."
        )
    else:
        raise ValueError(
            "please specify either serialization_dir or archive_path")

    # Resolve the token for create_repo. Previously use_auth_token=False left
    # huggingface_token unbound and crashed with UnboundLocalError.
    if isinstance(use_auth_token, str):
        huggingface_token = use_auth_token
    elif use_auth_token:
        huggingface_token = HfFolder.get_token()
    else:
        huggingface_token = None

    info_msg = f"Preparing repository '{repo_name}'"
    if organization is not None:
        info_msg += f" ({organization})"
    logging.info(info_msg)

    # Create the repo (or clone its content if it's nonempty)
    api = HfApi()
    repo_url = api.create_repo(
        name=repo_name,
        token=huggingface_token,
        organization=organization,
        private=False,
        exist_ok=True,
    )

    repo_local_path = Path(local_repo_path) / repo_name
    repo = Repository(repo_local_path,
                      clone_from=repo_url,
                      use_auth_token=use_auth_token)
    repo.git_pull(rebase=True)

    # Model file should be tracked with Git LFS
    repo.lfs_track(["*.th"])

    # Extract information from either serializable directory or a
    # .tar.gz file
    if serialization_dir is not None:
        for filename in working_dir.iterdir():
            _copy_allowed_file(filename, repo_local_path)
    else:
        with tempfile.TemporaryDirectory() as temp_dir:
            extracted_dir = Path(
                cached_path(working_dir, temp_dir, extract_archive=True))
            for filename in extracted_dir.iterdir():
                _copy_allowed_file(filename, repo_local_path)

    _create_model_card(repo_local_path)

    logging.info(f"Pushing repo {repo_name} to the Hugging Face Hub")
    repo.push_to_hub(commit_message=commit_message)

    logging.info(f"View your model in {repo_url}")
    return repo_url
示例#26
0
class HubMixingCommonTest(unittest.TestCase):
    """Test case holding a class-level ``HfApi`` client for ``ENDPOINT_STAGING``.

    The client is created once, when the class body is executed, and is
    shared by every test method through ``self._api``.
    """

    _api = HfApi(
        endpoint=ENDPOINT_STAGING,
    )
示例#27
0
文件: runner.py 项目: kaen2891/s3prl
    def push_to_huggingface_hub(self):
        """Create a downstream repository on the Hub and push training artifacts to it.

        The target namespace is ``self.args.hf_hub_org`` unless that value is
        the string "none" (case-insensitive), in which case the ``HF_USERNAME``
        environment variable is used. The repository name combines the upstream
        model id with either the upstream commit sha (when ``self.args.hub`` is
        "huggingface") or a random 8-character id.

        The whole experiment directory is copied into
        ``<expdir>/hf_hub/<repo_name>/<downstream>/<expname>``, the selected
        checkpoint is renamed to ``model.ckpt`` (tracked with Git LFS), a model
        card is written and everything is pushed.

        Raises:
            FileNotFoundError: if no ``*best*.ckpt`` checkpoint exists and the
                final ``states-<total_steps>.ckpt`` checkpoint is missing too.
        """
        if self.args.hf_hub_org.lower() != "none":
            organization = self.args.hf_hub_org
        else:
            organization = os.environ.get("HF_USERNAME")
        huggingface_token = HfFolder.get_token()
        print(f"[Runner] - Organisation to push fine-tuned model to: {organization}")

        # Extract upstream repository metadata
        if self.args.hub == "huggingface":
            model_info = HfApi().model_info(self.args.upstream, token=huggingface_token)
            downstream_model_id = model_info.sha
            # Exclude "/" characters from downstream repo ID
            upstream_model_id = model_info.modelId.replace("/", "__")
        else:
            upstream_model_id = self.args.upstream.replace("/", "__")
            downstream_model_id = str(uuid.uuid4())[:8]
        repo_name = f"{upstream_model_id}__{downstream_model_id}"

        # Create downstream repo on the Hub (exist_ok: reuse it if present)
        repo_url = HfApi().create_repo(
            token=huggingface_token,
            name=repo_name,
            organization=organization,
            exist_ok=True,
            private=False,
        )
        print(f"[Runner] - Created Hub repo: {repo_url}")

        # Clone the repo inside the experiment directory
        hf_hub_dir = "hf_hub"
        repo_root_dir = os.path.join(self.args.expdir, hf_hub_dir, repo_name)
        repo_task_dir = os.path.join(repo_root_dir, self.args.downstream, self.args.expname)
        print(f"[Runner] - Cloning Hub repo to {repo_root_dir}")
        model_repo = Repository(
            local_dir=repo_root_dir, clone_from=repo_url, use_auth_token=huggingface_token
        )
        # Pull latest changes if they exist
        model_repo.git_pull()

        # Copy checkpoints, tensorboard logs, and args / configs.
        # Note that this copies all files from the experiment directory,
        # including those from multiple runs. The clone itself lives under
        # <expdir>/hf_hub, so that folder is ignored to avoid copying it into itself.
        shutil.copytree(
            self.args.expdir,
            repo_task_dir,
            dirs_exist_ok=True,
            ignore=shutil.ignore_patterns(hf_hub_dir),
        )

        # By default we use model.ckpt in the PreTrainedModel interface, so
        # rename the best checkpoint to match this convention.
        # Path.glob yields entries in arbitrary order; sort so the choice
        # among several "best" checkpoints is deterministic.
        checkpoints = sorted(Path(repo_task_dir).glob("*best*.ckpt"))
        if not checkpoints:
            print("[Runner] - Did not find a best checkpoint! Using the final checkpoint instead ...")
            ckpt_path = os.path.join(
                repo_task_dir, f"states-{self.config['runner']['total_steps']}.ckpt"
            )
            if not os.path.isfile(ckpt_path):
                raise FileNotFoundError(
                    f"[Runner] - No best checkpoint found and final checkpoint is missing: {ckpt_path}"
                )
        elif len(checkpoints) > 1:
            print(f"[Runner] - More than one best checkpoint found! Using {checkpoints[0]} as default ...")
            ckpt_path = checkpoints[0]
        else:
            print(f"[Runner] - Found best checkpoint {checkpoints[0]}!")
            ckpt_path = checkpoints[0]
        shutil.move(ckpt_path, os.path.join(repo_task_dir, "model.ckpt"))
        model_repo.lfs_track("*.ckpt")

        # Write model card
        self._create_model_card(repo_root_dir)

        # Push everything to the Hub
        print("[Runner] - Pushing model files to the Hub ...")
        model_repo.push_to_hub()
        print("[Runner] - Training run complete!")
示例#28
0
class SnapshotDownloadTests(unittest.TestCase):
    """Integration tests for ``snapshot_download`` against the staging Hub.

    ``setUp`` clones ``{USER}/{REPO_NAME}`` and records three commits:

    1. ``main``:  ``dummy_file.txt`` = "v1"                      -> ``first_commit_hash``
    2. ``main``:  ``dummy_file.txt`` = "v2",
                  ``dummy_file_2.txt`` = "v3"                    -> ``second_commit_hash``
    3. ``other``: ``dummy_file_2.txt`` = "v4"                    -> ``third_commit_hash``

    ``tearDown`` deletes the Hub repo and the local clone.
    """

    _api = HfApi(endpoint=ENDPOINT_STAGING)

    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._token = TOKEN
        cls._api.set_access_token(TOKEN)

    @retry_endpoint
    def setUp(self) -> None:
        # Start from a clean local clone.
        if os.path.exists(REPO_NAME):
            shutil.rmtree(REPO_NAME, onerror=set_write_permission_and_retry)
        logger.info(f"Does {REPO_NAME} exist: {os.path.exists(REPO_NAME)}")
        repo = Repository(
            REPO_NAME,
            clone_from=f"{USER}/{REPO_NAME}",
            use_auth_token=self._token,
            git_user="******",
            git_email="*****@*****.**",
        )

        with repo.commit("Add file to main branch"):
            with open("dummy_file.txt", "w+") as f:
                f.write("v1")

        self.first_commit_hash = repo.git_head_hash()

        with repo.commit("Add file to main branch"):
            with open("dummy_file.txt", "w+") as f:
                f.write("v2")
            with open("dummy_file_2.txt", "w+") as f:
                f.write("v3")

        self.second_commit_hash = repo.git_head_hash()

        with repo.commit("Add file to other branch", branch="other"):
            with open("dummy_file_2.txt", "w+") as f:
                f.write("v4")

        self.third_commit_hash = repo.git_head_hash()

    def tearDown(self) -> None:
        self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
        # Use the same error handler as setUp so the local clone is removed
        # even when some of its files cannot be deleted directly.
        shutil.rmtree(REPO_NAME, onerror=set_write_permission_and_retry)

    def _assert_snapshot(
        self,
        storage_folder,
        *,
        num_files,
        present,
        absent=(),
        dummy_contents,
        commit_hash,
    ):
        """Check a downloaded snapshot folder.

        Asserts, in order: the number of entries in ``storage_folder``, that
        every name in ``present`` is there and every name in ``absent`` is not,
        that ``dummy_file.txt`` contains ``dummy_contents``, and that the folder
        path contains ``commit_hash`` (the revision's commit sha).
        """
        folder_contents = os.listdir(storage_folder)
        self.assertEqual(len(folder_contents), num_files)
        for name in present:
            self.assertIn(name, folder_contents)
        for name in absent:
            self.assertNotIn(name, folder_contents)

        with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
            self.assertEqual(f.read(), dummy_contents)

        # folder name contains the revision's commit sha.
        self.assertIn(commit_hash, storage_folder)

    def test_download_model(self):
        # Test `main` branch
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname
            )

            # folder contains the two files contributed and the .gitattributes
            self._assert_snapshot(
                storage_folder,
                num_files=3,
                present=["dummy_file.txt", "dummy_file_2.txt", ".gitattributes"],
                dummy_contents="v2",
                commit_hash=self.second_commit_hash,
            )

        # Test with specific revision
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision=self.first_commit_hash,
                cache_dir=tmpdirname,
            )

            # at the first commit only dummy_file.txt and .gitattributes exist
            self._assert_snapshot(
                storage_folder,
                num_files=2,
                present=["dummy_file.txt", ".gitattributes"],
                dummy_contents="v1",
                commit_hash=self.first_commit_hash,
            )

    def test_download_private_model(self):
        self._api.update_repo_visibility(
            token=self._token, repo_id=REPO_NAME, private=True
        )

        # Test download fails without token
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaisesRegex(
                requests.exceptions.HTTPError, "404 Client Error"
            ):
                _ = snapshot_download(
                    f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname
                )

        # Test we can download with token from cache
        with tempfile.TemporaryDirectory() as tmpdirname:
            # NOTE(review): the token saved here is never removed by this
            # test — confirm that is intended.
            HfFolder.save_token(self._token)
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="main",
                cache_dir=tmpdirname,
                use_auth_token=True,
            )

            # folder contains the two files contributed and the .gitattributes
            self._assert_snapshot(
                storage_folder,
                num_files=3,
                present=["dummy_file.txt", "dummy_file_2.txt", ".gitattributes"],
                dummy_contents="v2",
                commit_hash=self.second_commit_hash,
            )

        # Test we can download with explicit token
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="main",
                cache_dir=tmpdirname,
                use_auth_token=self._token,
            )

            # folder contains the two files contributed and the .gitattributes
            self._assert_snapshot(
                storage_folder,
                num_files=3,
                present=["dummy_file.txt", "dummy_file_2.txt", ".gitattributes"],
                dummy_contents="v2",
                commit_hash=self.second_commit_hash,
            )

        self._api.update_repo_visibility(
            token=self._token, repo_id=REPO_NAME, private=False
        )

    def test_download_model_local_only(self):
        # Test no branch specified
        with tempfile.TemporaryDirectory() as tmpdirname:
            # first download folder to cache it
            snapshot_download(f"{USER}/{REPO_NAME}", cache_dir=tmpdirname)

            # now load from cache
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                cache_dir=tmpdirname,
                local_files_only=True,
            )

            # folder contains the two files contributed and the .gitattributes
            self._assert_snapshot(
                storage_folder,
                num_files=3,
                present=["dummy_file.txt", "dummy_file_2.txt", ".gitattributes"],
                dummy_contents="v2",
                commit_hash=self.second_commit_hash,
            )

        # Test with specific revision branch
        with tempfile.TemporaryDirectory() as tmpdirname:
            # first download folder to cache it
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="other",
                cache_dir=tmpdirname,
            )

            # now load from cache
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="other",
                cache_dir=tmpdirname,
                local_files_only=True,
            )

            # folder contains the two files contributed and the .gitattributes
            self._assert_snapshot(
                storage_folder,
                num_files=3,
                present=["dummy_file.txt", ".gitattributes"],
                dummy_contents="v2",
                commit_hash=self.third_commit_hash,
            )

        # Test with specific revision hash
        with tempfile.TemporaryDirectory() as tmpdirname:
            # first download folder to cache it
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision=self.first_commit_hash,
                cache_dir=tmpdirname,
            )

            # now load from cache
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision=self.first_commit_hash,
                cache_dir=tmpdirname,
                local_files_only=True,
            )

            # at the first commit only dummy_file.txt and .gitattributes exist
            self._assert_snapshot(
                storage_folder,
                num_files=2,
                present=["dummy_file.txt", ".gitattributes"],
                dummy_contents="v1",
                commit_hash=self.first_commit_hash,
            )

    def test_download_model_local_only_multiple(self):
        # Test `main` branch
        with tempfile.TemporaryDirectory() as tmpdirname:
            # download both from branch and from commit
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                cache_dir=tmpdirname,
            )

            snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision=self.first_commit_hash,
                cache_dir=tmpdirname,
            )

            # now load from cache and make sure warning to be raised
            with self.assertWarns(Warning):
                snapshot_download(
                    f"{USER}/{REPO_NAME}",
                    cache_dir=tmpdirname,
                    local_files_only=True,
                )

        # cache multiple commits and make sure correct commit is taken
        with tempfile.TemporaryDirectory() as tmpdirname:
            # first download folder to cache it
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                cache_dir=tmpdirname,
            )

            # now load folder from another branch
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="other",
                cache_dir=tmpdirname,
            )

            # now make sure that loading "main" branch gives correct branch
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                cache_dir=tmpdirname,
                local_files_only=True,
            )

            # folder contains the two files contributed and the .gitattributes;
            # its name contains the 2nd commit sha and not the 3rd
            self._assert_snapshot(
                storage_folder,
                num_files=3,
                present=["dummy_file.txt", ".gitattributes"],
                dummy_contents="v2",
                commit_hash=self.second_commit_hash,
            )

    def check_download_model_with_regex(self, regex, allow=True):
        # Test `main` branch
        allow_regex = regex if allow else None
        ignore_regex = regex if not allow else None

        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="main",
                cache_dir=tmpdirname,
                allow_regex=allow_regex,
                ignore_regex=ignore_regex,
            )

            # only the two contributed files are kept; .gitattributes is filtered out
            self._assert_snapshot(
                storage_folder,
                num_files=2,
                present=["dummy_file.txt", "dummy_file_2.txt"],
                absent=[".gitattributes"],
                dummy_contents="v2",
                commit_hash=self.second_commit_hash,
            )

    def test_download_model_with_allow_regex(self):
        self.check_download_model_with_regex("*.txt")

    def test_download_model_with_allow_regex_list(self):
        self.check_download_model_with_regex(["dummy_file.txt", "dummy_file_2.txt"])

    def test_download_model_with_ignore_regex(self):
        self.check_download_model_with_regex(".gitattributes", allow=False)

    def test_download_model_with_ignore_regex_list(self):
        self.check_download_model_with_regex(["*.git*", "*.pt"], allow=False)
 def setUpClass(cls):
     """Log into the staging Hub once and store the client and token on the class.

     NOTE(review): ``setUpClass`` is normally a ``@classmethod``; no decorator
     is visible above this fragment — confirm it exists in the full file.
     """
     staging_api = HfApi(endpoint=ENDPOINT_STAGING)
     cls._api = staging_api
     cls._token = staging_api.login(username=USER, password=PASS)
示例#30
0
def main():
    """Entry point for downstream training / evaluation.

    Parses the downstream arguments and config, optionally initialises
    torch.distributed and Hugging Face Hub authentication, dumps the run's
    args/config (leader process only), seeds all RNGs and finally calls
    ``Runner.<args.mode>()``.

    Raises:
        RuntimeError: when resuming (``--past_exp`` in train mode) with a DDP
            setup or world size different from the one stored in the checkpoint.
        ValueError: when ``args.hub == "huggingface"`` but the ``HF_USERNAME``
            or ``HF_PASSWORD`` environment variable is not set.
        AttributeError: if ``args.mode`` is not a method of ``Runner``.
    """
    torch.multiprocessing.set_sharing_strategy('file_system')
    torchaudio.set_audio_backend('sox_io')
    hack_isinstance()

    # get config and arguments
    args, config, backup_files = get_downstream_args()
    if args.cache_dir is not None:
        torch.hub.set_dir(args.cache_dir)

    # When torch.distributed.launch is used
    if args.local_rank is not None:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(args.backend)

    if args.mode == 'train' and args.past_exp:
        # Resuming: the distributed setup must match the checkpoint's.
        # torch.load unpickles the file — only load trusted checkpoints.
        ckpt = torch.load(args.init_ckpt, map_location='cpu')

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        now_use_ddp = is_initialized()
        original_use_ddp = ckpt['Args'].local_rank is not None
        if now_use_ddp != original_use_ddp:
            raise RuntimeError(
                f'DDP usage differs from the checkpoint: {now_use_ddp} != {original_use_ddp}'
            )

        if now_use_ddp:
            now_world = get_world_size()
            original_world = ckpt['WorldSize']
            if now_world != original_world:
                raise RuntimeError(
                    f'World size differs from the checkpoint: {now_world} != {original_world}'
                )

    if args.hub == "huggingface":
        args.from_hf_hub = True
        # Setup auth from environment variables
        hf_user = os.environ.get("HF_USERNAME")
        hf_password = os.environ.get("HF_PASSWORD")
        if not hf_user or not hf_password:
            raise ValueError(
                "The HF_USERNAME and HF_PASSWORD environment variables must be set "
                "when hub is 'huggingface'"
            )
        huggingface_token = HfApi().login(username=hf_user, password=hf_password)
        HfFolder.save_token(huggingface_token)
        print(f"Logged into Hugging Face Hub with user: {hf_user}")

    # Save command
    if is_leader_process():
        # Compute the tag once so the args and config dumps of one run share
        # the same suffix (get_time_tag() presumably depends on the clock, and
        # two calls could return different values).
        time_tag = get_time_tag()
        with open(os.path.join(args.expdir, f'args_{time_tag}.yaml'), 'w') as file:
            yaml.dump(vars(args), file)

        with open(os.path.join(args.expdir, f'config_{time_tag}.yaml'), 'w') as file:
            yaml.dump(config, file)

        for backup_file in backup_files:
            backup(backup_file, args.expdir)

    # Fix seed and make backends deterministic
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    if args.disable_cudnn:
        torch.backends.cudnn.enabled = False
    else:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    runner = Runner(args, config)
    # getattr instead of eval(): same dispatch, without evaluating a
    # command-line string as Python code.
    getattr(runner, args.mode)()