Ejemplo n.º 1
0
    def check_download_model_with_regex(self, regex, allow=True):
        """Download the `main` snapshot filtered by `regex` and verify its contents.

        Args:
            regex: Pattern(s) forwarded to `snapshot_download`.
            allow: When true, `regex` is passed as `allow_regex`; otherwise it
                is passed as `ignore_regex`.

        In both modes the resulting folder is expected to hold exactly the two
        dummy files, with `.gitattributes` filtered out.
        """
        # Exactly one of the two filters receives the pattern.
        allow_regex = regex if allow else None
        ignore_regex = None if allow else regex

        # Test `main` branch
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="main",
                cache_dir=tmpdirname,
                allow_regex=allow_regex,
                ignore_regex=ignore_regex,
            )

            # Only the two contributed files remain after filtering; the
            # .gitattributes file must have been excluded.
            folder_contents = os.listdir(storage_folder)
            self.assertEqual(len(folder_contents), 2)
            self.assertIn("dummy_file.txt", folder_contents)
            self.assertIn("dummy_file_2.txt", folder_contents)
            self.assertNotIn(".gitattributes", folder_contents)

            with open(os.path.join(storage_folder, "dummy_file.txt"),
                      "r", encoding="utf-8") as f:
                contents = f.read()
            self.assertEqual(contents, "v2")

            # folder name contains the revision's commit sha.
            self.assertIn(self.second_commit_hash, storage_folder)
Ejemplo n.º 2
0
    def test_download_model(self):
        """Check snapshots of the `main` branch and of the first commit."""
        # Test `main` branch
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(f"{USER}/{REPO_NAME}",
                                               revision="main",
                                               cache_dir=tmpdirname)

            # folder contains the two files contributed and the .gitattributes
            folder_contents = os.listdir(storage_folder)
            self.assertEqual(len(folder_contents), 3)
            self.assertIn("dummy_file.txt", folder_contents)
            self.assertIn("dummy_file_2.txt", folder_contents)
            self.assertIn(".gitattributes", folder_contents)

            with open(os.path.join(storage_folder, "dummy_file.txt"),
                      "r", encoding="utf-8") as f:
                contents = f.read()
            self.assertEqual(contents, "v2")

            # folder name contains the revision's commit sha.
            self.assertIn(self.second_commit_hash, storage_folder)

        # Test with specific revision
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision=self.first_commit_hash,
                cache_dir=tmpdirname,
            )

            # At the first commit the folder holds only one contributed file
            # plus the .gitattributes.
            folder_contents = os.listdir(storage_folder)
            self.assertEqual(len(folder_contents), 2)
            self.assertIn("dummy_file.txt", folder_contents)
            self.assertIn(".gitattributes", folder_contents)

            with open(os.path.join(storage_folder, "dummy_file.txt"),
                      "r", encoding="utf-8") as f:
                contents = f.read()
            self.assertEqual(contents, "v1")

            # folder name contains the revision's commit sha.
            self.assertIn(self.first_commit_hash, storage_folder)
Ejemplo n.º 3
0
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        use_auth_token,
        **model_kwargs,
    ):
        """Load a Keras model from a local directory or from a downloaded snapshot.

        If `model_id` is an existing directory it is loaded directly; otherwise
        it is treated as a repo id and fetched with `snapshot_download` using
        `revision` and `cache_dir`. A `config` entry in `model_kwargs` is
        removed before loading and attached to the returned model as
        `model.config`; the remaining keyword arguments are forwarded to
        `tf.keras.models.load_model`.

        TODO - `force_download`, `proxies`, `resume_download`,
        `local_files_only` and `use_auth_token` are accepted but not used,
        since snapshot_download is called instead of hf_hub_download.

        Raises:
            ImportError: If TensorFlow is not available.
        """
        # Bail out early when TensorFlow is missing.
        if not is_tf_available():
            raise ImportError(
                "Called a Tensorflow-specific function but could not import it."
            )
        import tensorflow as tf

        # TODO - Figure out what to do about these config values. Config is not going to be needed to load model
        config = model_kwargs.pop("config", None)

        # A local directory wins; anything else is resolved to a cached snapshot.
        if os.path.isdir(model_id):
            model_dir = model_id
        else:
            model_dir = snapshot_download(
                repo_id=model_id, revision=revision, cache_dir=cache_dir
            )

        loaded = tf.keras.models.load_model(model_dir, **model_kwargs)

        # For now the config is simply stored on the model as a new attribute.
        loaded.config = config
        return loaded
Ejemplo n.º 4
0
    def test_download_model_local_only_multiple(self):
        """Check `local_files_only` loading when several snapshots are cached."""
        # Cache two revisions, then load locally without naming one.
        with tempfile.TemporaryDirectory() as tmpdirname:
            # download both from branch and from commit
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                cache_dir=tmpdirname,
            )

            snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision=self.first_commit_hash,
                cache_dir=tmpdirname,
            )

            # now load from cache and make sure a warning is raised
            with self.assertWarns(Warning):
                snapshot_download(
                    f"{USER}/{REPO_NAME}",
                    cache_dir=tmpdirname,
                    local_files_only=True,
                )

        # cache multiple commits and make sure correct commit is taken
        with tempfile.TemporaryDirectory() as tmpdirname:
            # first download folder to cache it
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                cache_dir=tmpdirname,
            )

            # now load folder from another branch
            snapshot_download(
                f"{USER}/{REPO_NAME}",
                revision="other",
                cache_dir=tmpdirname,
            )

            # now make sure that loading without a revision, from local files
            # only, resolves to the default branch snapshot
            storage_folder = snapshot_download(
                f"{USER}/{REPO_NAME}",
                cache_dir=tmpdirname,
                local_files_only=True,
            )

            # folder contains the two files contributed and the .gitattributes
            folder_contents = os.listdir(storage_folder)
            self.assertEqual(len(folder_contents), 3)
            self.assertIn("dummy_file.txt", folder_contents)
            self.assertIn(".gitattributes", folder_contents)

            with open(os.path.join(storage_folder, "dummy_file.txt"),
                      "r", encoding="utf-8") as f:
                contents = f.read()
            self.assertEqual(contents, "v2")

            # folder name contains the 2nd commit sha.
            # NOTE(review): the `other` branch presumably points at a 3rd
            # commit; its absence from the path is not asserted here.
            self.assertIn(self.second_commit_hash, storage_folder)
Ejemplo n.º 5
0
    def test_download_private_model(self):
        """Check token handling of `snapshot_download` against a private repo.

        The repo is switched to private for the duration of the test and is
        switched back to public afterwards, even when an assertion fails, so
        a failure here cannot leave the repo private for other tests.
        """
        self._api.update_repo_visibility(token=self._token,
                                         name=REPO_NAME,
                                         private=True)
        try:
            # Test download fails without token
            with tempfile.TemporaryDirectory() as tmpdirname:
                with self.assertRaisesRegex(requests.exceptions.HTTPError,
                                            "404 Client Error"):
                    snapshot_download(f"{USER}/{REPO_NAME}",
                                      revision="main",
                                      cache_dir=tmpdirname)

            # Save the token so that `use_auth_token=True` can succeed
            # (presumably it picks up the token stored via HfFolder).
            HfFolder.save_token(self._token)

            # Test we can download with the token from cache (True) and with
            # an explicit token; both must yield the same snapshot.
            for auth_token in (True, self._token):
                with tempfile.TemporaryDirectory() as tmpdirname:
                    storage_folder = snapshot_download(
                        f"{USER}/{REPO_NAME}",
                        revision="main",
                        cache_dir=tmpdirname,
                        use_auth_token=auth_token,
                    )

                    # folder contains the two files contributed and the
                    # .gitattributes
                    folder_contents = os.listdir(storage_folder)
                    self.assertEqual(len(folder_contents), 3)
                    self.assertIn("dummy_file.txt", folder_contents)
                    self.assertIn("dummy_file_2.txt", folder_contents)
                    self.assertIn(".gitattributes", folder_contents)

                    with open(os.path.join(storage_folder, "dummy_file.txt"),
                              "r", encoding="utf-8") as f:
                        contents = f.read()
                    self.assertEqual(contents, "v2")

                    # folder name contains the revision's commit sha.
                    self.assertIn(self.second_commit_hash, storage_folder)
        finally:
            # Always restore public visibility, whatever happened above.
            self._api.update_repo_visibility(token=self._token,
                                             name=REPO_NAME,
                                             private=False)