def test_filter_datasets_by_task_ids(self):
    """Filtering by a task id yields datasets carrying that task tag."""
    _api = HfApi()
    dataset_filter = DatasetFilter(task_ids="automatic-speech-recognition")
    datasets = _api.list_datasets(filter=dataset_filter)
    self.assertGreater(len(datasets), 0)
    self.assertIn("task_ids:automatic-speech-recognition", datasets[0].tags)
def test_filter_models_with_cardData(self):
    """cardData must be attached only when explicitly requested.

    BUG FIX: the original first assertion was
    ``assertTrue([hasattr(model, "cardData") for model in models])`` — a
    non-empty list literal is always truthy, so it could never fail.
    Wrapping the generator in ``all(...)`` makes the assertion meaningful
    (and consistent with the second assertion below).
    """
    _api = HfApi()
    models = _api.list_models(filter="co2_eq_emissions", cardData=True)
    self.assertTrue(all(hasattr(model, "cardData") for model in models))
    # Without cardData=True the attribute must be absent from every result.
    models = _api.list_models(filter="co2_eq_emissions")
    self.assertTrue(all(not hasattr(model, "cardData") for model in models))
def test_list_datasets_full(self):
    """full=True returns many DatasetInfo objects, some with cardData."""
    _api = HfApi()
    datasets = _api.list_datasets(full=True)
    self.assertGreater(len(datasets), 100)
    self.assertIsInstance(datasets[0], DatasetInfo)
    has_card_data = any(dataset.cardData for dataset in datasets)
    self.assertTrue(has_card_data)
def hf_token(hf_api: HfApi):
    """Fixture: log in, yield the auth token, then best-effort logout."""
    token = hf_api.login(username=USER, password=PASS)
    yield token
    try:
        hf_api.logout(token)
    except requests.exceptions.HTTPError:
        # Teardown is best-effort; the token may already be invalid.
        pass
def test_filter_datasets_by_task_categories(self):
    """Filtering by a task category yields datasets carrying that tag."""
    _api = HfApi()
    dataset_filter = DatasetFilter(task_categories="audio-classification")
    datasets = _api.list_datasets(filter=dataset_filter)
    self.assertGreater(len(datasets), 0)
    self.assertIn("task_categories:audio-classification", datasets[0].tags)
def __init__(self, username=None, password=None):
    """Build the API client; log in only when both credentials are given."""
    self.api = HfApi()
    has_username = bool(username)
    has_password = bool(password)
    if has_username and has_password:
        self.token = self.api.login(username, password)
    elif has_username or has_password:
        # Partial credentials: warn instead of attempting a login.
        print(
            'Only a username or password was entered. You should include both to get authorized access'
        )
def test_filter_datasets_by_author_and_name(self):
    """Author + dataset-name filter pins down exactly one dataset."""
    _api = HfApi()
    dataset_filter = DatasetFilter(
        author="huggingface", dataset_name="DataMeasurementsFiles")
    datasets = _api.list_datasets(filter=dataset_filter)
    self.assertEqual(len(datasets), 1)
    first = datasets[0]
    self.assertIn("huggingface", first.author)
    self.assertIn("DataMeasurementsFiles", first.id)
def test_filter_emissions_with_min(self):
    """A lower emissions threshold excludes models below that value."""
    _api = HfApi()
    models = _api.list_models(emissions_thresholds=(5, None), cardData=True)
    emission_values = [model.cardData["co2_eq_emissions"] for model in models]
    self.assertTrue(
        all(value >= 5
            for value in emission_values
            if isinstance(value, (float, int))))
def test_filter_models_with_task(self):
    """A real task returns matching models; a bogus task returns none."""
    _api = HfApi()
    models = _api.list_models(
        filter=ModelFilter(task="fill-mask", model_name="albert-base-v2"))
    first = models[0]
    self.assertTrue("fill-mask" == first.pipeline_tag)
    self.assertTrue("albert-base-v2" in first.modelId)
    # A nonexistent task must come back empty.
    models = _api.list_models(filter=ModelFilter(task="dummytask"))
    self.assertGreater(1, len(models))
def test_filter_models_by_language(self):
    """French and English model listings should differ in size."""
    _api = HfApi()
    counts = {}
    for language in ("fr", "en"):
        results = _api.list_models(filter=ModelFilter(language=language))
        counts[language] = len(results)
    assert counts["fr"] != counts["en"]
def test_list_models_with_config(self):
    """fetch_config=True populates .config on at least one result."""
    _api = HfApi()
    models = _api.list_models(
        filter="adapter-transformers", fetch_config=True, limit=20)
    found_configs = sum(1 for model in models if model.config)
    self.assertGreater(found_configs, 0)
def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
    """Find models that can accept src_lang as input and return tgt_lang as output."""
    prefix = "Helsinki-NLP/opus-mt-"
    api = HfApi()
    # Restrict the hub listing to the Helsinki-NLP organization.
    model_ids = [
        info.modelId
        for info in api.list_models()
        if info.modelId.startswith("Helsinki-NLP")
    ]
    # Split "src-tgt" language pairs; models containing "+" can't be loaded.
    src_and_targ = [
        remove_prefix(m, prefix).lower().split("-")
        for m in model_ids
        if "+" not in m
    ]
    return [
        f"{prefix}{a}-{b}"
        for (a, b) in src_and_targ
        if src_lang in a and tgt_lang in b
    ]
def test_model_info_with_security(self):
    """securityStatus=True attaches the scan result to the model info."""
    _api = HfApi()
    model = _api.model_info(
        repo_id=DUMMY_MODEL_ID,
        revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
        securityStatus=True,
    )
    expected_status = {"containsInfected": False}
    self.assertEqual(getattr(model, "securityStatus"), expected_status)
def test_filter_models_with_library(self):
    """The library filter discriminates tensorflow vs pytorch results."""
    _api = HfApi()
    # wavlm-base-sd is not published for tensorflow: expect no hits.
    tf_filter = ModelFilter(
        "microsoft", model_name="wavlm-base-sd", library="tensorflow")
    self.assertGreater(1, len(_api.list_models(filter=tf_filter)))
    # The pytorch variant exists: expect at least one hit.
    pt_filter = ModelFilter(
        "microsoft", model_name="wavlm-base-sd", library="pytorch")
    self.assertGreater(len(_api.list_models(filter=pt_filter)), 0)
def test_filter_datasets_by_language(self):
    """Language filters work with plain strings and with search-arg tuples."""
    _api = HfApi()
    datasets = _api.list_datasets(filter=DatasetFilter(languages="en"))
    self.assertGreater(len(datasets), 0)
    self.assertIn("languages:en", datasets[0].tags)
    # Multiple languages via DatasetSearchArguments attributes.
    args = DatasetSearchArguments()
    multi_filter = DatasetFilter(
        languages=(args.languages.en, args.languages.fr))
    datasets = _api.list_datasets(filter=multi_filter)
    self.assertGreater(len(datasets), 0)
    first_tags = datasets[0].tags
    self.assertIn("languages:en", first_tags)
    self.assertIn("languages:fr", first_tags)
def test_model_info(self):
    """model_info resolves main by default and exact revisions on request."""
    _api = HfApi()
    model = _api.model_info(repo_id=DUMMY_MODEL_ID)
    self.assertIsInstance(model, ModelInfo)
    # Top of `main` is not the pinned historical commit.
    self.assertNotEqual(model.sha, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT)
    # Requesting that specific commit returns exactly that sha.
    model = _api.model_info(
        repo_id=DUMMY_MODEL_ID,
        revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT)
    self.assertIsInstance(model, ModelInfo)
    self.assertEqual(model.sha, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT)
def test_list_repo_files(self):
    """The dummy model repo contains exactly the expected file listing."""
    _api = HfApi()
    actual_files = _api.list_repo_files(repo_id=DUMMY_MODEL_ID)
    self.assertListEqual(
        actual_files,
        [
            ".gitattributes",
            "README.md",
            "config.json",
            "flax_model.msgpack",
            "merges.txt",
            "pytorch_model.bin",
            "tf_model.h5",
            "vocab.json",
        ],
    )
def run(self):
    """Log the user out: clear local credentials, then revoke server-side."""
    token = HfFolder.get_token()
    if token is None:
        print("Not logged in")
        exit()
    # Drop local state before contacting the server.
    HfFolder.delete_token()
    HfApi.unset_access_token()
    try:
        self._api.logout(token)
    except HTTPError as e:
        # Logging out with an access token will return a client error.
        if e.response.status_code != 400:
            raise e
    print("Successfully logged out.")
def test_list_models_complex_query(self):
    """Combined filter + sort + direction + limit query behaves sanely."""
    # 10 most recent models tagged both "bert" and "jax",
    # ordered by last modified date (descending).
    _api = HfApi()
    models = _api.list_models(
        filter=("bert", "jax"),
        sort="lastModified",
        direction=-1,
        limit=10,
    )
    self.assertGreater(len(models), 1)
    self.assertLessEqual(len(models), 10)
    first = models[0]
    self.assertIsInstance(first, ModelInfo)
    for tag in ("bert", "jax"):
        self.assertIn(tag, first.tags)
def login_password_event(t):
    """Widget callback: read credentials from the form and log in."""
    user = input_widget.value
    pw = password_widget.value
    # Erase password and clear value to make sure it's not saved in the notebook.
    password_widget.value = ""
    clear_output()
    _login(HfApi(), username=user, password=pw)
def test_filter_models_with_complex_query(self):
    """Task + multi-library filter returns matching models.

    BUG FIX: both trailing assertions originally passed a list
    comprehension straight to ``assertTrue`` — a non-empty list is always
    truthy, so neither assertion could ever fail. Each is now wrapped in
    ``all(...)`` so every returned model is actually checked.
    """
    _api = HfApi()
    args = ModelSearchArguments()
    f = ModelFilter(
        task=args.pipeline_tag.TextClassification,
        library=[args.library.PyTorch, args.library.TensorFlow],
    )
    models = _api.list_models(filter=f)
    self.assertGreater(len(models), 1)
    self.assertTrue(
        all("text-classification" in model.pipeline_tag
            or "text-classification" in model.tags
            for model in models))
    self.assertTrue(
        all("pytorch" in model.tags and "tf" in model.tags
            for model in models))
def test_filter_datasets_with_cardData(self):
    """cardData appears only when requested via cardData=True."""
    _api = HfApi()
    datasets = _api.list_datasets(cardData=True)
    with_card = sum(
        getattr(dataset, "cardData", None) is not None
        for dataset in datasets)
    self.assertGreater(with_card, 0)
    # Default listing must never carry cardData.
    datasets = _api.list_datasets()
    self.assertTrue(
        all(getattr(dataset, "cardData", None) is None
            for dataset in datasets))
def test_name_org_deprecation_warning():
    """Passing `name` to repo CRUD methods must raise a FutureWarning."""
    api = HfApi(endpoint=ENDPOINT_STAGING)
    REPO_NAME = repo_name("crud")
    methods_and_kwargs = (
        ("create_repo", {}),
        ("update_repo_visibility", {"private": False}),
        ("delete_repo", {}),
    )
    expected_match = re.escape(
        "`name` and `organization` input arguments are deprecated")
    for method, kwargs in methods_and_kwargs:
        with pytest.warns(FutureWarning, match=expected_match):
            getattr(api, method)(
                name=REPO_NAME,
                token=TOKEN,
                repo_type=REPO_TYPE_MODEL,
                **kwargs,
            )
def test_name_org_deprecation_error():
    """CRUD methods reject `name`+`repo_id` together, and require one."""
    api = HfApi(endpoint=ENDPOINT_STAGING)
    REPO_NAME = repo_name("crud")
    methods_and_kwargs = (
        ("create_repo", {}),
        ("update_repo_visibility", {"private": False}),
        ("delete_repo", {}),
    )
    # Supplying both identifiers is ambiguous and must fail.
    for method, kwargs in methods_and_kwargs:
        with pytest.raises(ValueError, match=re.escape("Only pass `repo_id`")):
            getattr(api, method)(
                repo_id="test",
                name=REPO_NAME,
                token=TOKEN,
                repo_type=REPO_TYPE_MODEL,
                **kwargs,
            )
    # Supplying neither identifier must also fail.
    for method, kwargs in methods_and_kwargs:
        with pytest.raises(ValueError, match="No name provided"):
            getattr(api, method)(
                token=TOKEN, repo_type=REPO_TYPE_MODEL, **kwargs)
def test_dataset_info(self):
    """dataset_info returns populated info and honors pinned revisions."""
    _api = HfApi()
    dataset = _api.dataset_info(repo_id=DUMMY_DATASET_ID)
    # Card data and siblings must both be present and non-empty.
    self.assertTrue(
        isinstance(dataset.cardData, dict) and len(dataset.cardData) > 0)
    self.assertTrue(
        isinstance(dataset.siblings, list) and len(dataset.siblings) > 0)
    self.assertIsInstance(dataset, DatasetInfo)
    self.assertNotEqual(
        dataset.sha, DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT)
    # Asking for a specific commit returns exactly that sha.
    dataset = _api.dataset_info(
        repo_id=DUMMY_DATASET_ID,
        revision=DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT,
    )
    self.assertIsInstance(dataset, DatasetInfo)
    self.assertEqual(dataset.sha, DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT)
def test_tags(self):
    """The models-tags-by-type endpoint yields non-empty tag groups."""
    _api = HfApi()
    response = requests.get(f"{_api.endpoint}/api/models-tags-by-type")
    response.raise_for_status()
    tags = ModelTags(response.json())
    for kind in ("library", "language", "license", "dataset", "pipeline_tag"):
        self.assertTrue(len(getattr(tags, kind).keys()) > 0)
def test_hub_configs(self):
    """I put require_torch_gpu cause I only want this to run with self-scheduled."""
    org = "sshleifer"
    model_ids = [
        info.modelId
        for info in HfApi().list_models()
        if info.modelId.startswith(org)
    ]
    # Known-broken repos we deliberately skip.
    allowed_to_be_broken = {
        "sshleifer/blenderbot-3B",
        "sshleifer/blenderbot-90M",
    }
    failures = []
    for model_id in model_ids:
        if model_id in allowed_to_be_broken:
            continue
        try:
            AutoConfig.from_pretrained(model_id)
        except Exception:
            failures.append(model_id)
    assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
def test_tags(self):
    """The datasets-tags-by-type endpoint yields non-empty tag groups."""
    _api = HfApi()
    response = requests.get(f"{_api.endpoint}/api/datasets-tags-by-type")
    response.raise_for_status()
    tags = DatasetTags(response.json())
    for kind in (
        "languages",
        "multilinguality",
        "language_creators",
        "task_categories",
        "size_categories",
        "benchmark",
        "task_ids",
        "licenses",
    ):
        self.assertTrue(len(getattr(tags, kind).keys()) > 0)
def push_to_hub(
    save_directory: Optional[str],
    model_id: Optional[str] = None,
    repo_url: Optional[str] = None,
    commit_message: Optional[str] = "add model",
    organization: Optional[str] = None,
    private: Optional[bool] = None,  # FIX: was `bool = None` (invalid default for bool)
) -> str:
    """Push the contents of ``save_directory`` to the Hugging Face hub.

    Parameters:
        save_directory (:obj:`Union[str, os.PathLike]`):
            Directory having model weights & config.
        model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`):
            Repo name in huggingface_hub. If not specified, repo name will be
            same as `save_directory`.
        repo_url (:obj:`str`, `optional`):
            Specify this in case you want to push to existing repo in hub.
        commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`):
            Message to commit while pushing.
        organization (:obj:`str`, `optional`):
            Organization in which you want to push your model.
        private (:obj:`bool`, `optional`):
            Whether the model repo should be private (requires a paid
            huggingface.co account).

    Returns:
        url to commit on remote repo.
    """
    if model_id is None:
        model_id = save_directory
    token = HfFolder.get_token()
    # Create the repo only when the caller did not supply an existing one.
    if repo_url is None:
        repo_url = HfApi().create_repo(
            token,
            model_id,
            organization=organization,
            private=private,
            repo_type=None,
            exist_ok=True,
        )
    repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token)
    return repo.push_to_hub(commit_message=commit_message)
def test_repo_id_no_warning():
    """Positional repo_id must not trigger any deprecation warnings."""
    api = HfApi(endpoint=ENDPOINT_STAGING)
    REPO_NAME = repo_name("crud")
    methods_and_kwargs = (
        ("create_repo", {}),
        ("update_repo_visibility", {"private": False}),
        ("delete_repo", {}),
    )
    for method, kwargs in methods_and_kwargs:
        with warnings.catch_warnings(record=True) as record:
            getattr(api, method)(
                REPO_NAME,
                token=TOKEN,
                repo_type=REPO_TYPE_MODEL,
                **kwargs,
            )
        assert not len(record)