Example #1
import random

import tensorflow as tf
from numpy.testing import assert_array_equal

# `io` and `tfrecord` are the modules under test (these snippets appear to be
# from a TFRecord helper package such as nobrainer); `csv_of_volumes` is a
# pytest fixture supplying a CSV of volume file paths. The later examples
# assume the same imports, plus `pytest`.
def test_write_read_float_labels(csv_of_volumes, tmp_path):  # noqa: F811
    files = io.read_csv(csv_of_volumes, skip_header=False)
    # Swap each row's label for a random float to exercise scalar labels.
    files = [(x, random.random()) for x, _ in files]
    filename_template = str(tmp_path / "data-{shard:03d}.tfrecords")
    examples_per_shard = 12
    tfrecord.write(
        files,
        filename_template=filename_template,
        examples_per_shard=examples_per_shard,
        processes=1,
    )

    # With examples_per_shard=12, nine shard files imply the fixture supplies
    # between 97 and 108 examples (ceil(n / 12) == 9).
    paths = sorted(tmp_path.glob("data-*.tfrecords"))
    assert len(paths) == 9
    assert (tmp_path / "data-008.tfrecords").is_file()

    # Read the shards back (they are written GZIP-compressed) and parse each
    # serialized example into a (volume, label) pair.
    dset = tf.data.TFRecordDataset(list(map(str, paths)),
                                   compression_type="GZIP")
    dset = dset.map(
        tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=True))

    # Each parsed example should round-trip: volume data and float label match.
    for ref, test in zip(files, dset):
        x, y = ref
        x = io.read_volume(x)
        assert_array_equal(x, test[0])
        assert_array_equal(y, test[1])
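
The expected shard count above is plain ceiling division. A minimal sketch of the arithmetic, where n_examples = 100 is only an assumption about the size of the csv_of_volumes fixture (the assertions alone only pin it to 97-108):

import math

n_examples = 100  # hypothetical fixture size; any value in 97-108 gives 9 shards
examples_per_shard = 12
n_shards = math.ceil(n_examples / examples_per_shard)
assert n_shards == 9
# Shards are numbered data-000 ... data-008:
assert "data-{shard:03d}.tfrecords".format(shard=n_shards - 1) == "data-008.tfrecords"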
Example #2
def tfrecord_dataset(
    file_pattern,
    volume_shape,
    shuffle,
    scalar_label,
    compressed=True,
    num_parallel_calls=AUTOTUNE,
):
    """Return `tf.data.Dataset` from TFRecord files."""
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
    # Read each of these files as a TFRecordDataset.
    # Assume all files have same compression type as the first file.
    compression_type = "GZIP" if compressed else None
    cycle_length = 1 if num_parallel_calls is None else num_parallel_calls
    dataset = dataset.interleave(
        map_func=lambda x: tf.data.TFRecordDataset(
            x, compression_type=compression_type),
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls,
    )
    parse_fn = parse_example_fn(volume_shape=volume_shape,
                                scalar_label=scalar_label)
    dataset = dataset.map(map_func=parse_fn,
                          num_parallel_calls=num_parallel_calls)
    return dataset
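
A usage sketch for the function above; the glob pattern and shapes are placeholders, and AUTOTUNE is assumed to be tf.data.experimental.AUTOTUNE as in the signature's default:

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Hypothetical call: read GZIP shards matching the pattern, parse them into
# (volume, scalar-label) pairs, and batch for training.
dataset = tfrecord_dataset(
    file_pattern="data-*.tfrecords",
    volume_shape=(8, 8, 8),
    shuffle=True,
    scalar_label=True,
    compressed=True,
    num_parallel_calls=AUTOTUNE,
)
dataset = dataset.batch(2)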
Example #3
def test_write_read_volume_labels_all_processes(csv_of_volumes,
                                                tmp_path):  # noqa: F811
    files = io.read_csv(csv_of_volumes, skip_header=False)
    filename_template = str(tmp_path / "data-{shard:03d}.tfrecords")
    examples_per_shard = 12
    tfrecord.write(
        files,
        filename_template=filename_template,
        examples_per_shard=examples_per_shard,
        processes=None,  # None presumably means "use all available processes"
    )

    paths = sorted(tmp_path.glob("data-*.tfrecords"))
    assert len(paths) == 9
    assert (tmp_path / "data-008.tfrecords").is_file()

    dset = tf.data.TFRecordDataset(list(map(str, paths)),
                                   compression_type="GZIP")
    dset = dset.map(
        tfrecord.parse_example_fn(volume_shape=(8, 8, 8), scalar_label=False))

    # With scalar_label=False, both the input and the label are volumes on disk.
    for ref, test in zip(files, dset):
        x, y = ref
        x, y = io.read_volume(x), io.read_volume(y)
        assert_array_equal(x, test[0])
        assert_array_equal(y, test[1])

    # A template without a named "shard" field cannot produce unique shard
    # names, so `tfrecord.write` is expected to reject it.
    with pytest.raises(ValueError):
        tfrecord.write(files,
                       filename_template="data/foobar-{}.tfrecords",
                       examples_per_shard=4)
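
The final pytest.raises check presumably hinges on the template lacking a named shard field; a minimal sketch of that validation rule (an assumption — the snippet only shows that a ValueError is expected):

import string

def has_shard_field(template):
    # True when the template contains a named "shard" replacement field.
    return any(field == "shard"
               for _, field, _, _ in string.Formatter().parse(template)
               if field is not None)

assert has_shard_field("data-{shard:03d}.tfrecords")    # accepted template
assert not has_shard_field("data/foobar-{}.tfrecords")  # rejected template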