def test_push_dataset_dict_to_hub_custom_features(self):
    """Pushing a DatasetDict built with explicit features round-trips them through the Hub."""
    custom_features = Features({"x": Value("int64"), "y": ClassLabel(names=["neg", "pos"])})
    dataset = Dataset.from_dict({"x": [1, 2, 3], "y": [0, 0, 1]}, features=custom_features)
    local_ds = DatasetDict({"test": dataset})

    # Timestamp-based repo name to avoid collisions between test runs.
    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
    try:
        local_ds.push_to_hub(ds_name, token=self._token)
        hub_ds = load_dataset(ds_name, download_mode="force_redownload")

        # Columns, feature order and feature types must all survive the round trip.
        self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds["test"].features.keys()), list(hub_ds["test"].features.keys()))
        self.assertDictEqual(local_ds["test"].features, hub_ds["test"].features)
    finally:
        namespace, repo_name = ds_name.split("/")
        self._api.delete_repo(repo_name, organization=namespace, token=self._token, repo_type="dataset")
def test_push_dataset_dict_to_hub_name_without_namespace(self):
    """Pushing with a bare repo name (no ``user/`` prefix) lands in the user's namespace.

    Also verifies that exactly the expected files end up in the repository.
    """
    ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
    local_ds = DatasetDict({"train": ds})

    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
    try:
        # Push using only the repo name; the Hub resolves it under USER's namespace.
        local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token)
        hub_ds = load_dataset(ds_name, download_mode="force_redownload")

        self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
        self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)

        # Ensure that there is a single file on the repository that has the correct name
        files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
        self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
    finally:
        # Bug fix: pass token=self._token so deleting the test repo cannot fail
        # with an authentication error — the sibling tests already do this.
        self._api.delete_repo(
            ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset"
        )
def test_push_dataset_dict_to_hub_multiple_files(self):
    """A small shard size makes push_to_hub split one split into several parquet files."""
    dataset = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})
    local_ds = DatasetDict({"train": dataset})

    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
    try:
        # 500 << 5 bytes per shard forces the 1000-row split into two shards.
        local_ds.push_to_hub(ds_name, token=self._token, shard_size=500 << 5)
        hub_ds = load_dataset(ds_name, download_mode="force_redownload")

        self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
        self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)

        # Ensure that there are two files on the repository that have the correct name
        repo_files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
        expected_files = [
            ".gitattributes",
            "data/train-00000-of-00002.parquet",
            "data/train-00001-of-00002.parquet",
        ]
        self.assertListEqual(repo_files, expected_files)
    finally:
        namespace, repo_name = ds_name.split("/")
        self._api.delete_repo(repo_name, organization=namespace, token=self._token, repo_type="dataset")
def test_push_dataset_dict_to_hub_no_token(self):
    """push_to_hub with no explicit token falls back to the stored credentials."""
    local_ds = DatasetDict({"train": Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})})

    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
    try:
        local_ds.push_to_hub(ds_name)  # deliberately no token= argument
        hub_ds = load_dataset(ds_name, download_mode="force_redownload")

        self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
        self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)

        # Ensure that there is a single file on the repository that has the correct name
        files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
        expected_patterns = [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
        matches = [fnmatch.fnmatch(actual, pattern) for actual, pattern in zip(files, expected_patterns)]
        self.assertTrue(all(matches))
    finally:
        self.cleanup_repo(ds_name)
def test_push_dataset_dict_to_hub_multiple_files(self):
    """A small max_shard_size makes push_to_hub split one split into several parquet files."""
    dataset = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})
    local_ds = DatasetDict({"train": dataset})

    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
    try:
        # 16KB per shard forces the 1000-row split into two shards.
        local_ds.push_to_hub(ds_name, token=self._token, max_shard_size="16KB")
        hub_ds = load_dataset(ds_name, download_mode="force_redownload")

        self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
        self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)

        # Ensure that there are two files on the repository that have the correct name
        files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
        expected_patterns = [
            ".gitattributes",
            "data/train-00000-of-00002-*.parquet",
            "data/train-00001-of-00002-*.parquet",
            "dataset_infos.json",
        ]
        matches = [fnmatch.fnmatch(actual, pattern) for actual, pattern in zip(files, expected_patterns)]
        self.assertTrue(all(matches))
    finally:
        self.cleanup_repo(ds_name)
def test_push_dataset_dict_to_hub_datasets_with_different_features(self):
    """Splits whose features disagree must be rejected with a ValueError."""
    ds_train = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
    ds_test = Dataset.from_dict({"x": [True, False, True], "y": ["a", "b", "c"]})
    local_ds = DatasetDict({"train": ds_train, "test": ds_test})

    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
    try:
        with self.assertRaises(ValueError):
            local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token)
    except AssertionError:
        # The push unexpectedly succeeded: remove the repo it created, then re-raise.
        self.cleanup_repo(ds_name)
        raise
def test_push_dataset_dict_to_hub_custom_splits(self):
    """A non-canonical split name ("random") round-trips through the Hub."""
    dataset = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
    local_ds = DatasetDict({"random": dataset})

    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
    try:
        local_ds.push_to_hub(ds_name, token=self._token)
        hub_ds = load_dataset(ds_name, download_mode="force_redownload")

        self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
        self.assertListEqual(list(local_ds["random"].features.keys()), list(hub_ds["random"].features.keys()))
        self.assertDictEqual(local_ds["random"].features, hub_ds["random"].features)
    finally:
        self.cleanup_repo(ds_name)
def test_push_streaming_dataset_dict_to_hub(self):
    """A streaming (iterable) DatasetDict can be pushed and round-tripped like a regular one."""
    dataset = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
    local_ds = DatasetDict({"train": dataset})

    with tempfile.TemporaryDirectory() as tmp:
        # Reload the saved dataset in streaming mode so push_to_hub gets an iterable input.
        local_ds.save_to_disk(tmp)
        local_ds = load_dataset(tmp, streaming=True)

        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
        try:
            local_ds.push_to_hub(ds_name, token=self._token)
            hub_ds = load_dataset(ds_name, download_mode="force_redownload")

            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)
        finally:
            self.cleanup_repo(ds_name)
def test_push_dataset_dict_to_hub_overwrite_files(self):
    """Re-pushing a dataset rewrites its shard files while unrelated repo files survive."""
    big = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})
    small = Dataset.from_dict({"x": list(range(100)), "y": list(range(100))})
    local_ds = DatasetDict({"train": big, "random": small})

    ds_name = f"{USER}/test-{int(time.time() * 10e3)}"

    # Push to hub two times, but the second time with a larger amount of files.
    # Verify that the new files contain the correct dataset.
    try:
        local_ds.push_to_hub(ds_name, token=self._token)

        with tempfile.TemporaryDirectory() as tmp:
            # Add a file starting with "data" to ensure it doesn't get deleted.
            bogus_path = Path(tmp) / "datafile.txt"
            bogus_path.write_text("Bogus file")
            self._api.upload_file(
                str(bogus_path),
                path_in_repo="datafile.txt",
                repo_id=ds_name,
                repo_type="dataset",
                token=self._token,
            )

            local_ds.push_to_hub(ds_name, token=self._token, shard_size=500 << 5)

            # Ensure that there are two files on the repository that have the correct name
            files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
            self.assertListEqual(
                files,
                [
                    ".gitattributes",
                    "data/random-00000-of-00001.parquet",
                    "data/train-00000-of-00002.parquet",
                    "data/train-00001-of-00002.parquet",
                    "datafile.txt",
                    "dataset_infos.json",
                ],
            )

            self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token)

            hub_ds = load_dataset(ds_name, download_mode="force_redownload")
            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)
    finally:
        self._api.delete_repo(
            ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset"
        )

    # Push to hub two times, but the second time with fewer files.
    # Verify that the new files contain the correct dataset and that non-necessary files have been deleted.
    try:
        local_ds.push_to_hub(ds_name, token=self._token, shard_size=500 << 5)

        with tempfile.TemporaryDirectory() as tmp:
            # Add a file starting with "data" to ensure it doesn't get deleted.
            bogus_path = Path(tmp) / "datafile.txt"
            bogus_path.write_text("Bogus file")
            self._api.upload_file(
                str(bogus_path),
                path_in_repo="datafile.txt",
                repo_id=ds_name,
                repo_type="dataset",
                token=self._token,
            )

            local_ds.push_to_hub(ds_name, token=self._token)

            # Ensure that there are two files on the repository that have the correct name
            files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
            self.assertListEqual(
                files,
                [
                    ".gitattributes",
                    "data/random-00000-of-00001.parquet",
                    "data/train-00000-of-00001.parquet",
                    "datafile.txt",
                    "dataset_infos.json",
                ],
            )

            # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset
            self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token)

            hub_ds = load_dataset(ds_name, download_mode="force_redownload")
            self.assertDictEqual(local_ds.column_names, hub_ds.column_names)
            self.assertListEqual(list(local_ds["train"].features.keys()), list(hub_ds["train"].features.keys()))
            self.assertDictEqual(local_ds["train"].features, hub_ds["train"].features)
    finally:
        self._api.delete_repo(
            ds_name.split("/")[1], organization=ds_name.split("/")[0], token=self._token, repo_type="dataset"
        )