Exemplo n.º 1
0
def test_other_set_producer():
    """Check the messages `other_set_producer` pushes for 'valid' and 'test'.

    Builds a fake JPEG archive (requesting exactly 21 images) and a fake
    patch archive with 7 patches per set, runs the producer against a mock
    PUSH socket, and verifies every (metadata, data) message pair it sent.
    """
    # Create some fake data.
    num = 21
    image_archive, filenames = create_fake_jpeg_tar(seed=1979,
                                                    min_num_images=num,
                                                    max_num_images=num)
    patches = create_fake_patch_images(filenames=filenames,
                                       num_train=7, num_valid=7, num_test=7)

    valid_patches = extract_patch_images(io.BytesIO(patches), 'valid')
    test_patches = extract_patch_images(io.BytesIO(patches), 'test')
    assert len(valid_patches) == 7
    assert len(test_patches) == 7

    # `RandomState.random_integers` is deprecated; `randint` excludes its
    # upper bound, so 51 keeps the labels in the same inclusive 0..50 range.
    groundtruth = numpy.random.RandomState(1979).randint(0, 51, size=num)
    assert len(groundtruth) == num
    # Labels are paired with the filenames in sorted order.
    gt_lookup = dict(zip(sorted(filenames), groundtruth))
    assert len(gt_lookup) == num

    def check(which_set, set_patches):
        """Run the producer for `which_set` and verify what the socket got.

        `set_patches` is the patch mapping extracted for the same set; it
        is used to compute the expected payload of each data message.
        """
        # Run other_set_producer and push to a fake socket.
        socket = MockSocket(zmq.PUSH)
        other_set_producer(socket, which_set, io.BytesIO(image_archive),
                           io.BytesIO(patches), groundtruth)

        # Now verify the data that socket received.
        with tarfile.open(fileobj=io.BytesIO(image_archive)) as tar:
            num_patched = 0
            for im_fn in filenames:
                # Verify the label and flags of the first (metadata)
                # message.
                label = gt_lookup[im_fn]
                metadata_msg = socket.sent.popleft()
                assert metadata_msg['type'] == 'send_pyobj'
                assert metadata_msg['flags'] == zmq.SNDMORE
                assert metadata_msg['obj'] == (im_fn, label)
                # Verify that the second (data) message came from
                # the right place, either a patch file or a TAR.
                data_msg = socket.sent.popleft()
                assert data_msg['type'] == 'send'
                assert data_msg['flags'] == 0
                expected, patched = load_from_tar_or_patch(tar, im_fn,
                                                           set_patches)
                num_patched += int(patched)
                assert data_msg['data'] == expected
            # Every patch of this set must have been used exactly once.
            assert num_patched == len(set_patches)

    check('valid', valid_patches)
    check('test', test_patches)
Exemplo n.º 2
0
def test_load_from_tar_or_patch():
    """Check that `load_from_tar_or_patch` picks the right image source.

    For every filename in a fake image archive, the loaded bytes must equal
    the patch entry when one exists (with the `patched` flag set), and the
    raw TAR member otherwise (with the flag unset).
    """
    # Fake image archive plus patches for every fourth filename.
    images, all_filenames = create_fake_jpeg_tar(3, min_num_images=200,
                                                 max_num_images=200,
                                                 gzip_probability=0.0)
    every_fourth = all_filenames[::4]
    patch_data = create_fake_patch_images(every_fourth, num_train=50,
                                          num_valid=0, num_test=0)
    patches = extract_patch_images(io.BytesIO(patch_data), 'train')
    assert len(patches) == 50

    with tarfile.open(fileobj=io.BytesIO(images)) as tar:
        for name in all_filenames:
            loaded, was_patched = load_from_tar_or_patch(tar, name, patches)
            if name not in patches:
                # Unpatched files must come straight from the archive.
                assert loaded == tar.extractfile(name).read()
                assert not was_patched
                continue
            # Patched files must be served from the patch mapping.
            assert loaded == patches[name]
            assert was_patched
Exemplo n.º 3
0
def test_train_set_producer():
    """Check the messages `train_set_producer` pushes for a TAR of TARs.

    Builds a fake outer archive of per-class inner archives, patches 10
    randomly chosen JPEGs, runs the producer against a mock PUSH socket and
    verifies each (metadata, image) message pair in archive order.
    """
    tar_data, names, jpeg_names = create_fake_tar_of_tars(20150923, 5,
                                                          min_num_images=45,
                                                          max_num_images=55)
    # Flatten the per-archive name lists (avoids the quadratic
    # `sum(lists, [])` idiom while producing the same list).
    all_jpegs = numpy.array([name for group in jpeg_names for name in group])
    numpy.random.RandomState(20150923).shuffle(all_jpegs)
    patched_files = all_jpegs[:10]
    patches_data = create_fake_patch_images(filenames=patched_files,
                                            num_train=10, num_valid=0,
                                            num_test=0)
    train_patches = extract_patch_images(io.BytesIO(patches_data), 'train')
    socket = MockSocket(zmq.PUSH)
    # Map each inner archive's base name (text before the first '.') to an
    # integer label equal to its position in `names`.
    wnid_map = dict(zip((n.split('.')[0] for n in names), range(len(names))))

    train_set_producer(socket, io.BytesIO(tar_data), io.BytesIO(patches_data),
                       wnid_map)

    # Verify against the very archive that was handed to the producer; it
    # is wrapped in a fresh BytesIO, so there is no need to regenerate it.
    # The outer archive is opened once rather than once per inner archive.
    with tarfile.open(fileobj=io.BytesIO(tar_data)) as outer_tar:
        for tar_name in names:
            # The label key depends only on the inner archive's name.
            key = tar_name.split('.')[0]
            with tarfile.open(fileobj=outer_tar.extractfile(tar_name)) as tar:
                for record in tar:
                    jpeg = record.name
                    metadata_msg = socket.sent.popleft()
                    assert metadata_msg['type'] == 'send_pyobj'
                    assert metadata_msg['flags'] == zmq.SNDMORE
                    assert metadata_msg['obj'] == (jpeg, wnid_map[key])

                    image_msg = socket.sent.popleft()
                    assert image_msg['type'] == 'send'
                    assert image_msg['flags'] == 0
                    if jpeg in train_patches:
                        assert image_msg['data'] == train_patches[jpeg]
                    else:
                        image_data, _ = load_from_tar_or_patch(tar, jpeg,
                                                               train_patches)
                        assert image_msg['data'] == image_data
Exemplo n.º 4
0
def test_extract_patch_images():
    """Check the per-set patch counts extracted from a fake patch archive.

    The expected counts (14 / 15 / 21) presumably mirror the defaults of
    `create_fake_patch_images` -- confirm against that helper.
    """
    archive = create_fake_patch_images()
    expected_counts = (('train', 14), ('valid', 15), ('test', 21))
    for which_set, count in expected_counts:
        # Each extraction gets its own stream over the same archive bytes.
        extracted = extract_patch_images(io.BytesIO(archive), which_set)
        assert len(extracted) == count