def test_write_read_float_labels(csv_of_volumes, tmp_path):  # noqa: F811
    files = io.read_csv(csv_of_volumes, skip_header=False)
    # Replace each volume label with a random scalar (float) label.
    files = [(x, random.random()) for x, _ in files]
    filename_template = str(tmp_path / "data-{shard:03d}.tfrecords")
    examples_per_shard = 12
    tfrecord.write(
        files,
        filename_template=filename_template,
        examples_per_shard=examples_per_shard,
        processes=1,
    )

    paths = sorted(tmp_path.glob("data-*.tfrecords"))
    assert len(paths) == 9
    assert (tmp_path / "data-008.tfrecords").is_file()

    dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP")
    dset = dset.map(
        tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=True)
    )

    # Each parsed example should match the original volume and its scalar label.
    for ref, test in zip(files, dset):
        x, y = ref
        x = io.read_volume(x)
        assert_array_equal(x, test[0])
        assert_array_equal(y, test[1])
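# A sketch of the sharding arithmetic behind the asserts above, assuming the
# csv_of_volumes fixture yields 100 volumes (a hypothetical count consistent
# with the asserts): with 12 examples per shard, ceil(100 / 12) == 9 shards,
# numbered data-000.tfrecords through data-008.tfrecords.
#
#     import math
#     math.ceil(100 / 12)  # -> 9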
def tfrecord_dataset(
    file_pattern,
    volume_shape,
    shuffle,
    scalar_label,
    compressed=True,
    num_parallel_calls=AUTOTUNE,
):
    """Return a `tf.data.Dataset` of parsed examples from TFRecord files."""
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
    # Read each of these files as a TFRecordDataset.
    # Assume all files have the same compression type as the first file.
    compression_type = "GZIP" if compressed else None
    cycle_length = 1 if num_parallel_calls is None else num_parallel_calls
    dataset = dataset.interleave(
        map_func=lambda x: tf.data.TFRecordDataset(
            x, compression_type=compression_type
        ),
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls,
    )
    parse_fn = parse_example_fn(volume_shape=volume_shape, scalar_label=scalar_label)
    dataset = dataset.map(map_func=parse_fn, num_parallel_calls=num_parallel_calls)
    return dataset
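# A minimal usage sketch for `tfrecord_dataset`, not a definitive recipe: the
# file pattern and shapes below are hypothetical, and it assumes the shards
# were written with GZIP compression (matching the default `compressed=True`).
#
#     dataset = tfrecord_dataset(
#         file_pattern="data/data-*.tfrecords",
#         volume_shape=(8, 8, 8),
#         shuffle=True,
#         scalar_label=False,
#     )
#     for features, labels in dataset.take(1):
#         print(features.shape, labels.shape)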
def test_write_read_volume_labels_all_processes(csv_of_volumes, tmp_path):  # noqa: F811
    files = io.read_csv(csv_of_volumes, skip_header=False)
    filename_template = str(tmp_path / "data-{shard:03d}.tfrecords")
    examples_per_shard = 12
    # Use all available processes when writing (processes=None).
    tfrecord.write(
        files,
        filename_template=filename_template,
        examples_per_shard=examples_per_shard,
        processes=None,
    )

    paths = sorted(tmp_path.glob("data-*.tfrecords"))
    assert len(paths) == 9
    assert (tmp_path / "data-008.tfrecords").is_file()

    dset = tf.data.TFRecordDataset(list(map(str, paths)), compression_type="GZIP")
    dset = dset.map(
        tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=False)
    )

    # Each parsed example should match the original feature and label volumes.
    for ref, test in zip(files, dset):
        x, y = ref
        x, y = io.read_volume(x), io.read_volume(y)
        assert_array_equal(x, test[0])
        assert_array_equal(y, test[1])

    # An invalid filename template should raise a ValueError.
    with pytest.raises(ValueError):
        tfrecord.write(
            files,
            filename_template="data/foobar-{}.tfrecords",
            examples_per_shard=4,
        )
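# The failure above presumably stems from the template lacking a "shard"
# format field. A sketch of the naming convention used throughout these tests
# ("{shard:03d}" zero-pads the shard index to three digits):
#
#     "data-{shard:03d}.tfrecords".format(shard=8)  # -> 'data-008.tfrecords'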