def test_get_data():
    """Check the sample-data CSV returned by ``nbutils.get_data``.

    The returned path must be an existing file.  Parsed with ``read_csv``, it
    must yield exactly ten rows, each a two-column pair of paths to files that
    exist on disk.
    """
    csv_path = nbutils.get_data()
    assert Path(csv_path).is_file()

    pairs = read_csv(csv_path)
    assert len(pairs) == 10
    assert all(len(pair) == 2 for pair in pairs)

    # Walk both members of every pair in row order (first column, then
    # second) and require each one to be an existing file.
    for member in (path for pair in pairs for path in pair):
        assert Path(member).is_file()
def test_convert(tmp_path):
    """Run the ``convert`` CLI command on the sample data and check its shards.

    The sample CSV from ``get_data`` is converted with two volumes per shard.
    The command must exit cleanly and write exactly five TFRecord shards
    (000-004) under ``data/``.  This assumes the sample CSV lists 10 examples,
    as ``test_get_data`` checks for ``nbutils.get_data()``.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        csvpath = get_data(str(tmp_path))
        tfrecords_template = Path("data/shard-{shard:03d}.tfrecords")
        tfrecords_template.parent.mkdir(exist_ok=True)

        # Build argv as an explicit list instead of splitting a formatted
        # string, so paths containing whitespace reach the CLI intact.
        # NOTE(review): this test spells the shard-size option
        # --volumes-per-shard, whereas test_convert_scalar_float_labels uses
        # --examples-per-shard; confirm which spelling the CLI accepts.
        args = [
            "convert",
            "--csv={}".format(csvpath),
            "--tfrecords-template={}".format(tfrecords_template),
            "--volume-shape", "256", "256", "256",
            "--volumes-per-shard=2",
            "--to-ras",
            "--no-verify-volumes",
        ]
        result = runner.invoke(climain.cli, args)
        # Surface the CLI's output (or the exception Click captured) on failure.
        assert result.exit_code == 0, result.output or repr(result.exception)

        # Shards 000-004 must exist and nothing past them.
        shard_path = str(tfrecords_template)
        for shard in range(5):
            assert Path(shard_path.format(shard=shard)).is_file()
        assert not Path(shard_path.format(shard=5)).is_file()
def test_convert_scalar_float_labels(tmp_path):
    """Run the ``convert`` CLI command on a CSV whose labels are scalar floats.

    The sample CSV from ``get_data`` is rewritten so that every feature path is
    paired with the scalar label ``1.0``.  Converting it with two examples per
    shard must exit cleanly and write exactly five TFRecord shards (000-004)
    under ``data/``.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        source_csv = get_data(str(tmp_path))
        # Make labels scalars.
        data = [(x, 1.0) for (x, _) in read_csv(source_csv)]

        # Write the rewritten CSV inside tmp_path so it lives (and is cleaned
        # up) with this test's temporary directory.
        csvpath = tmp_path / "filepaths-scalar-labels.csv"
        with open(csvpath, "w", newline="") as myfile:
            wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
            # NOTE(review): assumes read_csv, and the CLI reading the same
            # file, skips the first row as a header.  Without this row the
            # first example would be dropped silently.  The shard-count checks
            # below would likely not notice, since 9 examples at 2 per shard
            # still fill shards 000-004.  The round-trip assert below checks
            # this assumption.
            wr.writerow(("features", "labels"))
            wr.writerows(data)
        assert len(read_csv(csvpath)) == len(data)

        tfrecords_template = Path("data/shard-{shard:03d}.tfrecords")
        tfrecords_template.parent.mkdir(exist_ok=True)

        # Build argv as an explicit list instead of splitting a formatted
        # string, so paths containing whitespace reach the CLI intact.
        # NOTE(review): this test spells the shard-size option
        # --examples-per-shard, whereas test_convert uses --volumes-per-shard;
        # confirm which spelling the CLI accepts.
        args = [
            "convert",
            "--csv={}".format(csvpath),
            "--tfrecords-template={}".format(tfrecords_template),
            "--volume-shape", "256", "256", "256",
            "--examples-per-shard=2",
            "--to-ras",
            "--no-verify-volumes",
        ]
        result = runner.invoke(climain.cli, args)
        # Surface the CLI's output (or the exception Click captured) on failure.
        assert result.exit_code == 0, result.output or repr(result.exception)

        # Shards 000-004 must exist and nothing past them.
        shard_path = str(tfrecords_template)
        for shard in range(5):
            assert Path(shard_path.format(shard=shard)).is_file()
        assert not Path(shard_path.format(shard=5)).is_file()