def test_file_format_save_restore(
    tmp_path: pathlib.Path,
    file_format: file_adapters.FileFormat,
):
  """The file format is persisted with the dataset and restored on reload."""
  builder = testing.DummyDataset(data_dir=tmp_path, file_format=file_format)
  assert isinstance(builder.info.file_format, file_adapters.FileFormat)
  assert builder.info.file_format is file_format
  builder.download_and_prepare()

  # Restoring without passing `file_format=`: the saved format is picked up.
  restored = testing.DummyDataset(data_dir=tmp_path)
  assert restored.info.file_format is file_format

  # Explicitly re-passing the matching format is accepted.
  explicit = testing.DummyDataset(data_dir=tmp_path, file_format=file_format)
  assert explicit.info.file_format is file_format

  # Passing a format that conflicts with the saved one must be rejected.
  other_format = {
      file_adapters.FileFormat.TFRECORD: file_adapters.FileFormat.RIEGELI,
      file_adapters.FileFormat.RIEGELI: file_adapters.FileFormat.TFRECORD,
  }[file_format]
  with pytest.raises(ValueError, match="File format is already set to"):
    testing.DummyDataset(data_dir=tmp_path, file_format=other_format)
def test_file_format_values(tmp_path: pathlib.Path):
  """`file_format=` accepts None, a string, or a FileFormat enum value."""
  # None falls back to the default format (TFRECORD).
  builder = testing.DummyDataset(data_dir=tmp_path, file_format=None)
  assert builder.info.file_format == file_adapters.FileFormat.TFRECORD
  # A plain string is converted to the corresponding enum value.
  builder = testing.DummyDataset(data_dir=tmp_path, file_format="riegeli")
  assert builder.info.file_format == file_adapters.FileFormat.RIEGELI
  # An enum value is used as-is.
  builder = testing.DummyDataset(
      data_dir=tmp_path,
      file_format=file_adapters.FileFormat.RIEGELI,
  )
  assert builder.info.file_format == file_adapters.FileFormat.RIEGELI
  # Anything else is rejected with a ValueError.
  with pytest.raises(ValueError, match="is not a valid FileFormat"):
    testing.DummyDataset(data_dir=tmp_path, file_format="arrow")
def test_compute_split_info(tmp_path):
  """Split infos computed from the prepared shards match builder metadata.

  NOTE(review): a second `test_compute_split_info` (using the newer
  `filename_template=` API) is defined later in this file with the same
  name, so this earlier definition is shadowed and never collected by
  pytest (flake8 F811). Consider renaming or deleting one of the two —
  this one also uses the older `data_dir=` argument of
  `compute_split_info`; confirm that API still exists before reviving it.
  """
  builder = testing.DummyDataset(data_dir=tmp_path)
  builder.download_and_prepare()
  split_infos = compute_split_utils.compute_split_info(
      data_dir=tmp_path / builder.info.full_name,
      out_dir=tmp_path,
  )
  # The recomputed protos must match what the builder recorded at
  # prepare time.
  assert [s.to_proto() for s in split_infos
         ] == [s.to_proto() for s in builder.info.splits.values()]
  # Split info are correctly saved
  split_path = tmp_path / compute_split_utils._out_filename('train')
  split_info = compute_split_utils._split_info_from_path(split_path)
  assert builder.info.splits['train'].to_proto() == split_info.to_proto()
def test_write_metadata(
    tmp_path: pathlib.Path,
    file_format,
):
  """`write_metadata` rebuilds a loadable dataset from bare record files."""
  tmp_path = utils.as_path(tmp_path)
  src_builder = testing.DummyDataset(
      data_dir=tmp_path / 'origin',
      file_format=file_format,
  )
  src_builder.download_and_prepare()

  dst_dir = tmp_path / 'copy'
  dst_dir.mkdir()
  # Copy only the record shards; deliberately leave the dataset info behind.
  for src_file in src_builder.data_path.iterdir():
    if naming.FilenameInfo.is_valid(src_file.name):
      src_file.copy(dst_dir / src_file.name)

  # Without an explicit file format, pass the split infos through; with one,
  # let `write_metadata` recompute them from the shards.
  if file_format is None:
    split_infos = list(src_builder.info.splits.values())
  else:
    split_infos = None

  metadata_path = dst_dir / 'dataset_info.json'
  assert not metadata_path.exists()
  write_metadata_utils.write_metadata(
      data_dir=dst_dir,
      features=src_builder.info.features,
      split_infos=split_infos,
      description='my test description.')
  assert metadata_path.exists()

  # The copy can now be restored as a read-only builder...
  builder = read_only_builder.builder_from_directory(dst_dir)
  assert builder.name == 'dummy_dataset'
  assert builder.version == '1.0.0'
  assert set(builder.info.splits) == {'train'}
  assert builder.info.splits['train'].num_examples == 3
  assert builder.info.description == 'my test description.'

  # ...and yields exactly the same examples as the original dataset.
  src_examples = list(src_builder.as_dataset(split='train').as_numpy_iterator())
  dst_examples = list(builder.as_dataset(split='train').as_numpy_iterator())
  assert src_examples == dst_examples
def test_compute_split_info(tmp_path):
  """Split infos computed via a filename template match builder metadata."""
  builder = testing.DummyDataset(data_dir=tmp_path)
  builder.download_and_prepare()

  template = naming.ShardedFileTemplate(
      dataset_name=builder.name,
      data_dir=builder.data_dir,
      filetype_suffix=builder.info.file_format.file_suffix,
  )
  split_infos = compute_split_utils.compute_split_info(
      out_dir=tmp_path,
      filename_template=template,
  )
  # The recomputed protos must match what the builder recorded at
  # prepare time.
  expected = [s.to_proto() for s in builder.info.splits.values()]
  assert [s.to_proto() for s in split_infos] == expected

  # The per-split info file was written to `out_dir` and can be read back.
  train_template = template.replace(data_dir=tmp_path, split='train')
  split_info = compute_split_utils._split_info_from_path(train_template)
  assert builder.info.splits['train'].to_proto() == split_info.to_proto()
def dummy_builder(tmp_path_factory):
  """Dummy dataset shared accross tests."""
  # Prepare once and yield so pytest can share the builder as a fixture.
  prepared = testing.DummyDataset(
      data_dir=tmp_path_factory.mktemp('datasets'))
  prepared.download_and_prepare()
  yield prepared