Example #1
0
def test_base64_image_deserializer(tmpdir):
    """Check Base64ImageDeserializer output when composed with a CTFDeserializer.

    Ten random 5x7 RGB images are PNG-encoded. They are written, base64
    encoded, to a tab-separated file with one ``seq_id, label, image`` line
    per image. A CTF file maps the same sequence ids to an ``|index`` value.
    For every minibatch drawn from the composed MinibatchSource, each decoded
    image must equal the original image selected by that sample's ``index``,
    with its channel axis reversed.
    """
    import base64
    import io
    import uuid
    from PIL import Image

    images, b64_images = [], []

    np.random.seed(1)
    for _ in range(10):
        pixels = np.random.randint(0, 2**8, (5, 7, 3))
        image = Image.fromarray(pixels.astype('uint8'), "RGB")
        buf = io.BytesIO()
        image.save(buf, format='PNG')
        assert image.width == 7 and image.height == 5
        b64_images.append(base64.b64encode(buf.getvalue()))
        images.append(np.array(image))

    # Each image gets a sequence id of uid ^ i (distinct for distinct i).
    # The same ids are reused in the CTF file below. The assertions rely on
    # this to pair every image sample with its 'index' value.
    image_data = str(tmpdir / 'mbdata1.txt')
    seq_ids = []
    uid = uuid.uuid1().int >> 64
    with open(image_data, 'wb') as f:
        for i, encoded in enumerate(b64_images):
            seq_id = str(uid ^ i).encode('ascii')
            seq_ids.append(seq_id)
            label = str(i).encode('ascii')
            f.write(seq_id + b'\t' + label + b'\t' + encoded + b'\n')

    ctf_data = str(tmpdir / 'mbdata2.txt')
    with open(ctf_data, 'wb') as f:
        for i, sid in enumerate(seq_ids):
            f.write(sid + b'\t' + b'|index ' + str(i).encode('ascii') + b'\n')

    transforms = [xforms.scale(width=7, height=5, channels=3)]
    b64_deserializer = Base64ImageDeserializer(
        image_data,
        StreamDefs(images=StreamDef(field='image', transforms=transforms),
                   labels=StreamDef(field='label', shape=10)))

    ctf_deserializer = CTFDeserializer(
        ctf_data, StreamDefs(index=StreamDef(field='index', shape=1)))

    mb_source = MinibatchSource([ctf_deserializer, b64_deserializer])
    assert isinstance(mb_source, MinibatchSource)

    # The stream handles are loop-invariant; look them up once.
    index_stream = mb_source.streams['index']
    image_stream = mb_source.streams['images']

    for _ in range(100):
        mb = mb_source.next_minibatch(10)
        index = mb[index_stream].asarray().flatten()
        results = mb[image_stream].asarray()

        for i in range(10):
            # Original images are RGB, while OpenCV produces BGR images, so
            # reverse the last (channel) axis of the original before comparing.
            bgr_image = images[int(index[i])][:, :, ::-1]
            # Unlike `(a == b).all()`, this does not broadcast: it fails on a
            # shape mismatch and reports the differing elements.
            np.testing.assert_array_equal(results[i][0], bgr_image)
Example #2
0
def test_base64_is_equal_image(tmpdir):
    """Check that Base64ImageDeserializer and ImageDeserializer agree.

    Ten random 5x7 RGB images are PNG-encoded and stored twice:
    - base64 encoded, inline in one tab-separated mapping file
      (``label, image`` per line);
    - as standalone ``.png`` files referenced from an ImageDeserializer
      mapping file (``path, label`` per line).

    A MinibatchSource reading both must yield identical image arrays for
    every minibatch.
    """
    import base64
    import io
    from PIL import Image
    np.random.seed(1)

    file_mapping_path = str(tmpdir / 'file_mapping.txt')
    base64_mapping_path = str(tmpdir / 'base64_mapping.txt')

    with open(file_mapping_path, 'w') as file_mapping:
        with open(base64_mapping_path, 'w') as base64_mapping:
            for i in range(10):
                pixels = np.random.randint(0, 2**8, (5, 7, 3))
                image = Image.fromarray(pixels.astype('uint8'), "RGB")
                buf = io.BytesIO()
                image.save(buf, format='PNG')
                assert image.width == 7 and image.height == 5
                png_bytes = buf.getvalue()

                label = str(i)
                # Write the base64 mapping line.
                encoded = base64.b64encode(png_bytes).decode('ascii')
                base64_mapping.write('%s\t%s\n' % (label, encoded))

                # Write the same PNG as a file, plus a mapping line for it.
                # NOTE(review): '.../' is presumably CNTK's prefix for a path
                # relative to the mapping file's directory (the PNG is written
                # into tmpdir next to it). Confirm against the reader docs.
                file_name = label + '.png'
                with open(str(tmpdir / file_name), 'wb') as f:
                    f.write(png_bytes)
                file_mapping.write('.../%s\t%s\n' % (file_name, label))

    transforms = [xforms.scale(width=7, height=5, channels=3)]
    b64_deserializer = Base64ImageDeserializer(
        base64_mapping_path,
        StreamDefs(images1=StreamDef(field='image', transforms=transforms),
                   labels1=StreamDef(field='label', shape=10)))

    file_image_deserializer = ImageDeserializer(
        file_mapping_path,
        StreamDefs(images2=StreamDef(field='image', transforms=transforms),
                   labels2=StreamDef(field='label', shape=10)))

    mb_source = MinibatchSource([b64_deserializer, file_image_deserializer])

    # The stream handles are loop-invariant; look them up once.
    images1_stream = mb_source.streams['images1']
    images2_stream = mb_source.streams['images2']

    for _ in range(20):
        mb = mb_source.next_minibatch(1)
        images1 = mb[images1_stream].asarray()
        images2 = mb[images2_stream].asarray()
        # Unlike `(a == b).all()`, this does not broadcast: it fails on a
        # shape mismatch and reports the differing elements.
        np.testing.assert_array_equal(images1, images2)