def test_other_set_producer():
    """Check the messages other_set_producer pushes for 'valid' and 'test'.

    Builds a fake 21-image JPEG TAR and a patch archive holding 7 patches
    for each of the train, valid and test sets, runs the producer against
    a ``MockSocket`` for both sets, and verifies that each image yields a
    metadata message followed by a data message.
    """
    # Create some fake data.
    num = 21
    image_archive, filenames = create_fake_jpeg_tar(
        seed=1979, min_num_images=num, max_num_images=num)
    patches = create_fake_patch_images(
        filenames=filenames, num_train=7, num_valid=7, num_test=7)

    valid_patches = extract_patch_images(io.BytesIO(patches), 'valid')
    test_patches = extract_patch_images(io.BytesIO(patches), 'test')
    assert len(valid_patches) == 7
    assert len(test_patches) == 7

    # randint's upper bound is exclusive, so 51 keeps labels in [0, 50].
    # It replaces random_integers, which NumPy has deprecated.
    groundtruth = numpy.random.RandomState(1979).randint(0, 51, size=num)
    assert len(groundtruth) == num
    gt_lookup = dict(zip(sorted(filenames), groundtruth))
    assert len(gt_lookup) == num

    def check(which_set, set_patches):
        """Run the producer for `which_set` and verify every sent message.

        `set_patches` is the patch mapping extracted for that same set; it
        is used to compute the expected image bytes and patch count.
        """
        # Run other_set_producer and push to a fake socket.
        socket = MockSocket(zmq.PUSH)
        other_set_producer(socket, which_set, io.BytesIO(image_archive),
                           io.BytesIO(patches), groundtruth)

        # Now verify the data that socket received.
        with tarfile.open(fileobj=io.BytesIO(image_archive)) as tar:
            num_patched = 0
            for im_fn in filenames:
                # Verify the label and flags of the first (metadata)
                # message.
                label = gt_lookup[im_fn]
                metadata_msg = socket.sent.popleft()
                assert metadata_msg['type'] == 'send_pyobj'
                assert metadata_msg['flags'] == zmq.SNDMORE
                assert metadata_msg['obj'] == (im_fn, label)

                # Verify that the second (data) message came from
                # the right place, either a patch file or a TAR.
                data_msg = socket.sent.popleft()
                assert data_msg['type'] == 'send'
                assert data_msg['flags'] == 0
                expected, patched = load_from_tar_or_patch(tar, im_fn,
                                                           set_patches)
                num_patched += int(patched)
                assert data_msg['data'] == expected
            # Every patch belonging to this set must have been used.
            assert num_patched == len(set_patches)

    check('valid', valid_patches)
    check('test', test_patches)
def test_load_from_tar_or_patch():
    """Verify load_from_tar_or_patch against patched and unpatched names.

    For a name present in the patch mapping the returned bytes must equal
    the patch and the flag must be truthy; otherwise the bytes must equal
    the TAR member's contents and the flag must be falsy.
    """
    # Setup fake tar files.
    archive_bytes, member_names = create_fake_jpeg_tar(
        3, min_num_images=200, max_num_images=200, gzip_probability=0.0)
    patch_archive = create_fake_patch_images(
        member_names[::4], num_train=50, num_valid=0, num_test=0)
    patch_lookup = extract_patch_images(io.BytesIO(patch_archive), 'train')
    assert len(patch_lookup) == 50

    with tarfile.open(fileobj=io.BytesIO(archive_bytes)) as archive:
        for name in member_names:
            loaded, was_patched = load_from_tar_or_patch(archive, name,
                                                         patch_lookup)
            if name not in patch_lookup:
                # Unpatched: the bytes must come straight from the TAR.
                assert loaded == archive.extractfile(name).read()
                assert not was_patched
                continue
            # Patched: the patch contents win over the TAR member.
            assert loaded == patch_lookup[name]
            assert was_patched
def test_train_set_producer():
    """Check the messages train_set_producer pushes for a fake TAR of TARs.

    Ten of the fake JPEG names are chosen (by a seeded shuffle) to have
    train patches. For every record of every inner TAR the test expects a
    metadata message ``(jpeg_name, class_index)`` followed by a data
    message carrying either the patch or the original image bytes.
    """
    seed = 20150923
    tar_data, names, jpeg_names = create_fake_tar_of_tars(
        seed, 5, min_num_images=45, max_num_images=55)
    # Flatten the per-archive name lists in linear time; sum(lists, [])
    # copies the accumulator on every addition.
    all_jpegs = numpy.array([name for group in jpeg_names for name in group])
    numpy.random.RandomState(seed).shuffle(all_jpegs)
    patched_files = all_jpegs[:10]
    patches_data = create_fake_patch_images(
        filenames=patched_files, num_train=10, num_valid=0, num_test=0)
    train_patches = extract_patch_images(io.BytesIO(patches_data), 'train')
    socket = MockSocket(zmq.PUSH)
    # Map each inner archive's stem (text before the first '.') to its index.
    wnid_map = {n.split('.')[0]: i for i, n in enumerate(names)}

    train_set_producer(socket, io.BytesIO(tar_data),
                       io.BytesIO(patches_data), wnid_map)

    # NOTE(review): the fixture is regenerated with identical arguments
    # before verification; presumably it is deterministic for a fixed
    # seed -- confirm against create_fake_tar_of_tars.
    tar_data, names, jpeg_names = create_fake_tar_of_tars(
        seed, 5, min_num_images=45, max_num_images=55)

    # Open the outer archive once instead of re-parsing it per inner TAR.
    with tarfile.open(fileobj=io.BytesIO(tar_data)) as outer_tar:
        for tar_name in names:
            key = tar_name.split('.')[0]
            with tarfile.open(
                    fileobj=outer_tar.extractfile(tar_name)) as tar:
                for record in tar:
                    jpeg = record.name

                    # First message: metadata, flagged as multipart.
                    metadata_msg = socket.sent.popleft()
                    assert metadata_msg['type'] == 'send_pyobj'
                    assert metadata_msg['flags'] == zmq.SNDMORE
                    assert metadata_msg['obj'] == (jpeg, wnid_map[key])

                    # Second message: the raw image bytes.
                    image_msg = socket.sent.popleft()
                    assert image_msg['type'] == 'send'
                    assert image_msg['flags'] == 0
                    if jpeg in train_patches:
                        assert image_msg['data'] == train_patches[jpeg]
                    else:
                        image_data, _ = load_from_tar_or_patch(
                            tar, jpeg, train_patches)
                        assert image_msg['data'] == image_data
def test_extract_patch_images():
    """Check the per-set patch counts of the default fake patch archive."""
    archive = create_fake_patch_images()
    expected_counts = (('train', 14), ('valid', 15), ('test', 21))
    for which_set, count in expected_counts:
        extracted = extract_patch_images(io.BytesIO(archive), which_set)
        assert len(extracted) == count