    # NOTE: These snippets are methods of a push_to_hub test class. Assumed to
    # be in scope (not shown here): `fnmatch`, `tempfile`, `time`, `Path` from
    # pathlib, plus `Dataset`, `DatasetDict`, `Features`, `Value`, `ClassLabel`
    # and `load_dataset` from `datasets`, and the fixtures `USER`,
    # `self._token` and `self._api` (an authenticated HfApi client).
    def test_push_dataset_dict_to_hub_custom_features(self):
        features = Features({
            "x": Value("int64"),
            "y": ClassLabel(names=["neg", "pos"])
        })
        ds = Dataset.from_dict({
            "x": [1, 2, 3],
            "y": [0, 0, 1]
        },
                               features=features)

        local_ds = DatasetDict({"test": ds})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
            local_ds.push_to_hub(ds_name, token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["test"].features.keys()),
                                 list(hub_ds["test"].features.keys()))
            self.assertDictEqual(local_ds["test"].features,
                                 hub_ds["test"].features)
        finally:
            self._api.delete_repo(ds_name.split("/")[1],
                                  organization=ds_name.split("/")[0],
                                  token=self._token,
                                  repo_type="dataset")

    def test_push_dataset_dict_to_hub_name_without_namespace(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"train": ds})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
            local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token)
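            # Passing only the bare repo name (no "user/" prefix) pushes to the
            # authenticated user's namespace, so the full ds_name still resolves.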
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()),
                                 list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features,
                                 hub_ds["train"].features)

            # Ensure that there is a single, correctly named data file on the repository
            files = sorted(
                self._api.list_repo_files(ds_name, repo_type="dataset"))
            self.assertListEqual(files, [
                ".gitattributes", "data/train-00000-of-00001.parquet",
                "dataset_infos.json"
            ])
        finally:
            self._api.delete_repo(ds_name.split("/")[1],
                                  organization=ds_name.split("/")[0],
                                  token=self._token,
                                  repo_type="dataset")

    def test_push_dataset_dict_to_hub_multiple_files(self):
        ds = Dataset.from_dict({
            "x": list(range(1000)),
            "y": list(range(1000))
        })

        local_ds = DatasetDict({"train": ds})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
            local_ds.push_to_hub(ds_name,
                                 token=self._token,
                                 shard_size=500 << 5)
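            # 500 << 5 == 16,000 bytes, i.e. roughly 16KB shards. Later
            # `datasets` releases replaced `shard_size` with `max_shard_size`
            # (compare the second multiple-files variant below, which passes
            # max_shard_size="16KB").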
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()),
                                 list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features,
                                 hub_ds["train"].features)

            # Ensure that the two shards are on the repository with the correct names
            files = sorted(
                self._api.list_repo_files(ds_name,
                                          repo_type="dataset",
                                          token=self._token))
            self.assertListEqual(files, [
                ".gitattributes", "data/train-00000-of-00002.parquet",
                "data/train-00001-of-00002.parquet"
            ])
        finally:
            self._api.delete_repo(ds_name.split("/")[1],
                                  organization=ds_name.split("/")[0],
                                  token=self._token,
                                  repo_type="dataset")

    # Example #4

    def test_push_dataset_dict_to_hub_no_token(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"train": ds})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
            local_ds.push_to_hub(ds_name)
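            # No token is passed explicitly, so push_to_hub falls back to the
            # locally cached credentials (presumably set up by the test fixture).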
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)

            # Ensure that there is a single, correctly named data file on the repository
            files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
            self.assertTrue(
                all(
                    fnmatch.fnmatch(file, expected_file)
                    for file, expected_file in zip(
                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
                    )
                )
            )
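            # (The "*" in the expected names allows for a version-dependent
            # suffix that some `datasets` releases append to the shard
            # filenames; fnmatch matches the pattern rather than exact names.)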
        finally:
            self.cleanup_repo(ds_name)
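
    # Hedged sketch (not part of the original snippets): `cleanup_repo` is
    # assumed to be a small helper on the test class that deletes the repo a
    # test created. With a recent huggingface_hub it could look like this:
    def cleanup_repo(self, ds_name):
        # Modern delete_repo accepts the full "namespace/name" repo_id.
        self._api.delete_repo(ds_name, token=self._token, repo_type="dataset")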

    # Example #5

    def test_push_dataset_dict_to_hub_multiple_files(self):
        ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})

        local_ds = DatasetDict({"train": ds})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
            local_ds.push_to_hub(ds_name, token=self._token, max_shard_size="16KB")
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)

            # Ensure that the two shards are on the repository with the correct names
            files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
            self.assertTrue(
                all(
                    fnmatch.fnmatch(file, expected_file)
                    for file, expected_file in zip(
                        files,
                        [
                            ".gitattributes",
                            "data/train-00000-of-00002-*.parquet",
                            "data/train-00001-of-00002-*.parquet",
                            "dataset_infos.json",
                        ],
                    )
                )
            )
        finally:
            self.cleanup_repo(ds_name)

    # Example #6

    def test_push_dataset_dict_to_hub_datasets_with_different_features(self):
        ds_train = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
        ds_test = Dataset.from_dict({"x": [True, False, True], "y": ["a", "b", "c"]})

        local_ds = DatasetDict({"train": ds_train, "test": ds_test})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
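            # All splits of a DatasetDict must share the same features;
            # pushing mismatched splits is expected to raise a ValueError.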
            with self.assertRaises(ValueError):
                local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token)
        except AssertionError:
            self.cleanup_repo(ds_name)
            raise

    # Example #7

    def test_push_dataset_dict_to_hub_custom_splits(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})

        local_ds = DatasetDict({"random": ds})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
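            # Split names are not limited to train/test/validation; a custom
            # split such as "random" is expected to round-trip unchanged.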
            local_ds.push_to_hub(ds_name, token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["random"].features.keys()), list(hub_ds["random"].features.keys()))
            self.assertDictEqual(local_ds["random"].features, hub_ds["random"].features)
        finally:
            self.cleanup_repo(ds_name)

    # Example #8

    def test_push_streaming_dataset_dict_to_hub(self):
        ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
        local_ds = DatasetDict({"train": ds})
        with tempfile.TemporaryDirectory() as tmp:
            local_ds.save_to_disk(tmp)
            local_ds = load_dataset(tmp, streaming=True)

            ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
            try:
                local_ds.push_to_hub(ds_name, token=self._token)
                hub_ds = load_dataset(ds_name, download_mode="force_redownload")

                self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
                self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
                self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)
            finally:
                self.cleanup_repo(ds_name)

    def test_push_dataset_dict_to_hub_overwrite_files(self):
        ds = Dataset.from_dict({
            "x": list(range(1000)),
            "y": list(range(1000))
        })
        ds2 = Dataset.from_dict({"x": list(range(100)), "y": list(range(100))})

        local_ds = DatasetDict({"train": ds, "random": ds2})

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"

        # Push to the Hub twice, the second time with a larger number of files.
        # Verify that the new files contain the correct dataset.
        try:
            local_ds.push_to_hub(ds_name, token=self._token)

            with tempfile.TemporaryDirectory() as tmp:
                # Add a file starting with "data" to ensure it doesn't get deleted.
                path = Path(tmp) / "datafile.txt"
                with open(path, "w") as f:
                    f.write("Bogus file")

                self._api.upload_file(str(path),
                                      path_in_repo="datafile.txt",
                                      repo_id=ds_name,
                                      repo_type="dataset",
                                      token=self._token)

            local_ds.push_to_hub(ds_name,
                                 token=self._token,
                                 shard_size=500 << 5)

            # Ensure that the two train shards (plus the random shard and the
            # bogus file) are on the repository with the correct names
            files = sorted(
                self._api.list_repo_files(ds_name,
                                          repo_type="dataset",
                                          token=self._token))
            self.assertListEqual(
                files,
                [
                    ".gitattributes",
                    "data/random-00000-of-00001.parquet",
                    "data/train-00000-of-00002.parquet",
                    "data/train-00001-of-00002.parquet",
                    "datafile.txt",
                    "dataset_infos.json",
                ],
            )

            self._api.delete_file("datafile.txt",
                                  repo_id=ds_name,
                                  repo_type="dataset",
                                  token=self._token)

            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()),
                                 list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features,
                                 hub_ds["train"].features)

        finally:
            self._api.delete_repo(ds_name.split("/")[1],
                                  organization=ds_name.split("/")[0],
                                  token=self._token,
                                  repo_type="dataset")

        # Push to the Hub twice, the second time with fewer files.
        # Verify that the new files contain the correct dataset and that obsolete files have been deleted.
        try:
            local_ds.push_to_hub(ds_name,
                                 token=self._token,
                                 shard_size=500 << 5)

            with tempfile.TemporaryDirectory() as tmp:
                # Add a file starting with "data" to ensure it doesn't get deleted.
                path = Path(tmp) / "datafile.txt"
                with open(path, "w") as f:
                    f.write("Bogus file")

                self._api.upload_file(str(path),
                                      path_in_repo="datafile.txt",
                                      repo_id=ds_name,
                                      repo_type="dataset",
                                      token=self._token)

            local_ds.push_to_hub(ds_name, token=self._token)

            # Ensure that the train split has been collapsed back to a single
            # shard and that the stale second shard from the previous push was
            # deleted
            files = sorted(
                self._api.list_repo_files(ds_name,
                                          repo_type="dataset",
                                          token=self._token))
            self.assertListEqual(
                files,
                [
                    ".gitattributes",
                    "data/random-00000-of-00001.parquet",
                    "data/train-00000-of-00001.parquet",
                    "datafile.txt",
                    "dataset_infos.json",
                ],
            )

            # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset
            self._api.delete_file("datafile.txt",
                                  repo_id=ds_name,
                                  repo_type="dataset",
                                  token=self._token)

            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()),
                                 list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features,
                                 hub_ds["train"].features)

        finally:
            self._api.delete_repo(ds_name.split("/")[1],
                                  organization=ds_name.split("/")[0],
                                  token=self._token,
                                  repo_type="dataset")