def test_loading_undownloaded(self, tmp_path):
    """Loading a dataset that was never downloaded must raise a RuntimeError with a download hint.

    Exercises ``load_dataset`` with ``download=False`` before ``Dataset.download()`` has been called.
    """
    init(DATADIR=tmp_path)
    hint = 'Did you forget to download the dataset (by specifying `download=True`)?'
    with pytest.raises(RuntimeError) as excinfo:
        load_dataset('wikitext103', version='1.0.1', download=False)
    assert hint in str(excinfo.value)
def test_loading_undownloaded(self, tmp_path):
    """Loading a never-downloaded dataset with ``download=False`` must raise FileNotFoundError.

    Exercises ``load_dataset`` before ``Dataset.download()`` has been called.
    """
    init(DATADIR=tmp_path)
    expected_message = 'Failed to load the dataset because some files are not found.'
    with pytest.raises(FileNotFoundError) as excinfo:
        load_dataset('wikitext103', version='1.0.1', download=False)
    assert expected_message in str(excinfo.value)
def test_loading_undownloaded(self, tmp_path):
    """Loading a dataset that was never downloaded must raise a RuntimeError with a download hint.

    Exercises ``load_dataset`` with ``download=False`` before ``Dataset.download()`` has been called.
    """
    init(DATADIR=tmp_path)
    expected_hint = ('Did you forget to download the dataset '
                     '(by calling this function with `download=True` for at least once)?')
    with pytest.raises(RuntimeError) as excinfo:
        load_dataset('wikitext103', version='1.0.1', download=False)
    assert expected_hint in str(excinfo.value)
def test_subdatasets_param(self, tmp_path):
    """Check that ``subdatasets`` is validated and restricts which subdatasets are loaded."""
    init(DATADIR=tmp_path)

    # A non-iterable value surfaces Python's native "not iterable" TypeError.
    with pytest.raises(TypeError) as excinfo:
        load_dataset('wikitext103', version='1.0.1', download=True, subdatasets=123)
    assert str(excinfo.value) == "'int' object is not iterable"

    # Requesting only 'train' must yield exactly that key.
    requested = ['train']
    loaded = load_dataset('wikitext103', version='1.0.1', download=True, subdatasets=requested)
    assert list(loaded.keys()) == requested
def test_version_param(self, tmp_path):
    """Check that the ``version`` parameter of ``load_dataset`` is validated and defaulted correctly.

    Covers three cases:
    - a non-str version raises ``TypeError``;
    - an unknown version string (empty or made up) raises ``KeyError`` whose message points at
      ``pydax.list_all_datasets()``;
    - omitting ``version`` loads the latest available version.
    """
    init(DATADIR=tmp_path)
    name = 'gmb'

    with pytest.raises(TypeError) as e:
        load_dataset(name, version=1.0)
    assert str(e.value) == 'The version parameter must be supplied a str.'

    # Both invalid versions must produce the same message. str() of a KeyError returns the repr of
    # its argument, hence the surrounding \'...\' in the expected text.
    for version in ('', 'fake_version'):
        with pytest.raises(KeyError) as e:
            load_dataset(name, version=version)
        assert str(e.value) == (
            f'\'"{version}" is not a valid PyDAX version for the dataset "{name}". '
            'You can view all valid datasets and their versions by running the function '
            'pydax.list_all_datasets().\'')

    # If no version is specified, make sure the latest version is grabbed.
    latest_version = str(max(version_parser(v) for v in list_all_datasets()[name]))
    assert load_dataset(name) == load_dataset(name, version=latest_version)
def test_name_param(self, tmp_path):
    """Check that ``load_dataset`` validates its ``name`` argument."""
    init(DATADIR=tmp_path)

    # Non-string names are rejected with a TypeError.
    with pytest.raises(TypeError) as excinfo:
        load_dataset(123)
    assert str(excinfo.value) == 'The name parameter must be supplied a str.'

    # Unknown names raise KeyError. str() of a KeyError is the repr of its argument, hence the
    # surrounding \'...\' in the expected text.
    unknown = 'fake_dataset'
    expected = (f'\'"{unknown}" is not a valid PyDAX dataset. You can view all valid datasets and their '
                'versions by running the function pydax.list_all_datasets().\'')
    with pytest.raises(KeyError) as excinfo:
        load_dataset(unknown)
    assert str(excinfo.value) == expected
def test_download_true(self, tmp_path, downloaded_gmb_dataset):
    """``load_dataset(..., download=True)`` must return the same data as the pre-downloaded fixture."""
    init(DATADIR=tmp_path)
    expected = downloaded_gmb_dataset.load()
    actual = load_dataset('gmb', version='1.0.2', download=True)
    assert expected == actual
def test_download_false(self, tmp_path, gmb_schema):
    """With ``download=False``, ``load_dataset`` must load data that was downloaded beforehand."""
    init(DATADIR=tmp_path)
    # Download and load the dataset directly into the location load_dataset is expected to read.
    reference = Dataset(gmb_schema,
                        data_dir=tmp_path / 'gmb' / '1.0.2',
                        mode=Dataset.InitializationMode.DOWNLOAD_AND_LOAD)
    loaded = load_dataset('gmb', version='1.0.2', download=False)
    assert reference.data == loaded
def test_default_dataset_schema_name(self, tmp_path, gmb_schema):
    """Test the default schemata name.

    The dataset is pre-downloaded under ``<DATADIR>/default/gmb/1.0.2``. After the "name" key is
    removed from the datasets schema, ``load_dataset(..., download=False)`` must load the same data
    from that directory.
    """
    init(DATADIR=tmp_path)
    data_dir = tmp_path / 'default' / 'gmb' / '1.0.2'
    gmb = Dataset(gmb_schema, data_dir=data_dir, mode=Dataset.InitializationMode.DOWNLOAD_AND_LOAD)

    datasets_schema = _get_schemata().schemata['datasets']._schema
    # Remove the "name" key so the default name is exercised. The schemata object may be cached and
    # shared with other tests (NOTE(review): confirm whether init() resets it), so restore the key
    # afterwards rather than leaking this mutation.
    removed_name = datasets_schema.pop('name')
    try:
        gmb_data = load_dataset('gmb', version='1.0.2', download=False)
    finally:
        datasets_schema['name'] = removed_name
    assert gmb.data == gmb_data
# This script tests all files in CODAIT/dax-schemata. It shouldn't report any error.

import yaml

import pydax


pydax.init(DATASET_SCHEMA_URL='datasets.yaml',
           FORMAT_SCHEMA_URL='formats.yaml',
           LICENSE_SCHEMA_URL='licenses.yaml')

# YAML files are UTF-8; don't depend on the platform's locale encoding.
with open('datasets.yaml', encoding='utf-8') as f:
    datasets = yaml.safe_load(f)

all_datasets = pydax.list_all_datasets()

# Dataset names must match those in the schema files. This helps ensure that PyDAX doesn't miss any
# dataset during the test.
assert frozenset(datasets['datasets']) == frozenset(all_datasets)
# Sanity check, in case all tests are skipped because of a minor error such as in formatting.
assert len(all_datasets) > 0

for name, versions in all_datasets.items():
    # Versions must match those in the schema files. This helps ensure that PyDAX doesn't miss any
    # dataset version during the test.
    assert frozenset(datasets['datasets'][name]) == frozenset(versions)
    # Sanity check, in case all tests are skipped because of a minor error such as in formatting.
    assert len(versions) > 0
    for version in versions:
        # Print dataset info. This also examines the relevant portion of licenses.yaml.
        print(pydax.get_dataset_metadata(name, version, human=True), end='\n\n')
        pydax.load_dataset(name=name, version=version, subdatasets=None)