Example No. 1
def test_file_format_save_restore(
    tmp_path: pathlib.Path,
    file_format: file_adapters.FileFormat,
):
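    # `file_format` is presumably supplied by a pytest parametrization
    # omitted from this excerpt.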
    builder = testing.DummyDataset(data_dir=tmp_path, file_format=file_format)

    assert isinstance(builder.info.file_format, file_adapters.FileFormat)
    assert builder.info.file_format is file_format

    builder.download_and_prepare()

    # When restoring the builder, we do not provide `file_format=`,
    # yet it is correctly restored.
    builder2 = testing.DummyDataset(data_dir=tmp_path)
    assert builder2.info.file_format is file_format

    # Explicitly passing the correct format is accepted.
    builder3 = testing.DummyDataset(data_dir=tmp_path, file_format=file_format)
    assert builder3.info.file_format is file_format

    # Providing an inconsistent format is rejected.
    with pytest.raises(ValueError, match="File format is already set to"):
        different_file_format = {
            file_adapters.FileFormat.TFRECORD:
                file_adapters.FileFormat.RIEGELI,
            file_adapters.FileFormat.RIEGELI:
                file_adapters.FileFormat.TFRECORD,
        }[file_format]
        testing.DummyDataset(data_dir=tmp_path,
                             file_format=different_file_format)
Example No. 2
def test_file_format_values(tmp_path: pathlib.Path):
    # Default file format
    builder = testing.DummyDataset(data_dir=tmp_path, file_format=None)
    assert builder.info.file_format == file_adapters.FileFormat.TFRECORD

    # str accepted
    builder = testing.DummyDataset(data_dir=tmp_path, file_format="riegeli")
    assert builder.info.file_format == file_adapters.FileFormat.RIEGELI

    # file_adapters.FileFormat accepted
    builder = testing.DummyDataset(
        data_dir=tmp_path, file_format=file_adapters.FileFormat.RIEGELI)
    assert builder.info.file_format == file_adapters.FileFormat.RIEGELI

    # Unknown value
    with pytest.raises(ValueError, match="is not a valid FileFormat"):
        testing.DummyDataset(data_dir=tmp_path, file_format="arrow")
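The string value accepted above and the "is not a valid FileFormat" error presumably come from standard Python string-valued enum lookup; below is a minimal, self-contained sketch (the member values mirror file_adapters.FileFormat but are assumptions in this snippet):

import enum


class FileFormat(enum.Enum):
    """Stand-in for file_adapters.FileFormat (member values assumed)."""
    TFRECORD = "tfrecord"
    RIEGELI = "riegeli"


# Lookup by value succeeds for known formats...
assert FileFormat("riegeli") is FileFormat.RIEGELI

# ...and otherwise raises ValueError("'arrow' is not a valid FileFormat"),
# which is the message matched by pytest.raises in the test above.
try:
    FileFormat("arrow")
except ValueError as error:
    print(error)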
Example No. 3
def test_compute_split_info(tmp_path):
  builder = testing.DummyDataset(data_dir=tmp_path)
  builder.download_and_prepare()

  split_infos = compute_split_utils.compute_split_info(
      data_dir=tmp_path / builder.info.full_name,
      out_dir=tmp_path,
  )

  assert [s.to_proto() for s in split_infos] == [
      s.to_proto() for s in builder.info.splits.values()
  ]

  # Split infos are correctly saved.
  split_path = tmp_path / compute_split_utils._out_filename('train')
  split_info = compute_split_utils._split_info_from_path(split_path)
  assert builder.info.splits['train'].to_proto() == split_info.to_proto()
Example No. 4
def test_write_metadata(
    tmp_path: pathlib.Path,
    file_format,
):
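  # `file_format` is presumably supplied by a pytest parametrization
  # omitted from this excerpt.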
  tmp_path = utils.as_path(tmp_path)

  src_builder = testing.DummyDataset(
      data_dir=tmp_path / 'origin',
      file_format=file_format,
  )
  src_builder.download_and_prepare()

  dst_dir = tmp_path / 'copy'
  dst_dir.mkdir()

  # Copy all the record files, but not the dataset info.
  for f in src_builder.data_path.iterdir():
    if naming.FilenameInfo.is_valid(f.name):
      f.copy(dst_dir / f.name)

  metadata_path = dst_dir / 'dataset_info.json'

  if file_format is None:
    split_infos = list(src_builder.info.splits.values())
  else:
    split_infos = None  # Auto-compute split infos

  assert not metadata_path.exists()
  write_metadata_utils.write_metadata(
      data_dir=dst_dir,
      features=src_builder.info.features,
      split_infos=split_infos,
      description='my test description.')
  assert metadata_path.exists()

  # After the metadata is written, the builder can be restored from the directory.
  builder = read_only_builder.builder_from_directory(dst_dir)
  assert builder.name == 'dummy_dataset'
  assert builder.version == '1.0.0'
  assert set(builder.info.splits) == {'train'}
  assert builder.info.splits['train'].num_examples == 3
  assert builder.info.description == 'my test description.'

  # Values are the same
  src_ds = src_builder.as_dataset(split='train')
  ds = builder.as_dataset(split='train')
  assert list(src_ds.as_numpy_iterator()) == list(ds.as_numpy_iterator())
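The naming.FilenameInfo.is_valid filter in the example above keeps only the record shards. As a rough illustration, TFDS shard files follow a '<dataset>-<split>.<suffix>-<shard>-of-<num_shards>' naming scheme; the regex below is an assumed approximation of that check, not the library's implementation:

import re

# Assumed shard-name pattern, e.g. 'dummy_dataset-train.tfrecord-00000-of-00001'.
_SHARD_NAME_RE = re.compile(
    r"^(?P<dataset>\w+)-(?P<split>\w+)\.(?P<suffix>\w+)"
    r"-(?P<shard>\d{5})-of-(?P<num_shards>\d{5})$")

assert _SHARD_NAME_RE.match("dummy_dataset-train.tfrecord-00000-of-00001")
assert not _SHARD_NAME_RE.match("dataset_info.json")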
Example No. 5
def test_compute_split_info(tmp_path):
  builder = testing.DummyDataset(data_dir=tmp_path)
  builder.download_and_prepare()

  filename_template = naming.ShardedFileTemplate(
      dataset_name=builder.name,
      data_dir=builder.data_dir,
      filetype_suffix=builder.info.file_format.file_suffix)
  split_infos = compute_split_utils.compute_split_info(
      out_dir=tmp_path,
      filename_template=filename_template,
  )

  assert [s.to_proto() for s in split_infos] == [
      s.to_proto() for s in builder.info.splits.values()
  ]

  # Split infos are correctly saved.
  filename_template = filename_template.replace(
      data_dir=tmp_path, split='train')
  split_info = compute_split_utils._split_info_from_path(filename_template)
  assert builder.info.splits['train'].to_proto() == split_info.to_proto()
Example No. 6
def dummy_builder(tmp_path_factory):
    """Dummy dataset shared accross tests."""
    data_dir = tmp_path_factory.mktemp('datasets')
    builder = testing.DummyDataset(data_dir=data_dir)
    builder.download_and_prepare()
    yield builder
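The yield statement and the tmp_path_factory argument suggest this function is registered as a pytest fixture whose decorator does not appear in this excerpt. A minimal sketch of how it might be declared and consumed (the fixture scope, the import path, and the consuming test are assumptions):

import pytest

from tensorflow_datasets import testing


@pytest.fixture(scope="module")
def dummy_builder(tmp_path_factory):
    """Dummy dataset shared across tests."""
    data_dir = tmp_path_factory.mktemp("datasets")
    builder = testing.DummyDataset(data_dir=data_dir)
    builder.download_and_prepare()
    yield builder


def test_splits(dummy_builder):  # Hypothetical consumer of the fixture.
    assert set(dummy_builder.info.splits) == {"train"}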