def test_group_by_keys():
    """Check that ``group_by_keys`` merges tar entries into per-sample dicts.

    Reads the ImageNet test shard and pipes the output of
    ``tarrecords.tardata`` through ``tarrecords.group_by_keys()``. It then
    asserts that the first sample has both a ``'png'`` and a ``'cls'`` key.
    """
    # Consume the lazy pipeline while the file is still open; the ``with``
    # block guarantees the handle is closed even if an assertion fails.
    with open("testdata/imagenet-000000.tgz", mode='rb') as stream:
        data = tarrecords.tardata(stream)
        data = tarrecords.group_by_keys()(data)
        samples = list(data)
    assert samples, "group_by_keys produced no samples"
    keys = list(samples[0].keys())
    assert 'png' in keys
    assert 'cls' in keys
def test_tardata():
    """Check the raw ``(name, payload)`` records yielded by ``tardata``.

    Asserts the following:

    - The first record from the ImageNet test shard is
      ``('10.cls', b'304')``.
    - Every record is a 2-element item.
    """
    # Materialize the records inside the ``with`` so the file is closed
    # afterwards instead of leaking the handle.
    with open("testdata/imagenet-000000.tgz", mode='rb') as stream:
        samples = list(tarrecords.tardata(stream))
    assert samples, "tardata produced no records"
    assert samples[0] == ('10.cls', b'304'), samples[0]
    assert {2} == {len(x) for x in samples}