예제 #1
0
def input_index_data(num_docs=None, batch_size=8, dataset_type='f30k'):
    """Yield an image Document and then a caption Document for each loaded pair.

    The data loader is built for the ``'test'`` split from paths under
    ``cur_dir``: ``<base>/<dataset_type>/images`` for the images and
    ``<base>/<dataset_type>/<captions file>`` for the captions, where
    ``<base>`` is ``'.'`` for ``'toy-data'`` and ``'data'`` otherwise, and the
    captions file is ``'dataset_flickr30k.json'`` for ``'f30k'`` and
    ``'captions.txt'`` otherwise.

    :param num_docs: optional limit; when truthy, iteration stops after the
        first batch for which ``batch_number * batch_size >= num_docs``. The
        check happens once per batch, so whole batches are always emitted,
        and every pair produces two documents.
    :param batch_size: batch size handed to the data loader and used in the
        ``num_docs`` check.
    :param dataset_type: dataset name; selects the captions file and folder.
    :return: generator of documents, alternating image and caption.
    """
    if dataset_type == 'f30k':
        captions_file = 'dataset_flickr30k.json'
    else:
        captions_file = 'captions.txt'
    base_folder = '.' if dataset_type == 'toy-data' else 'data'

    images_root = os.path.join(cur_dir, f'{base_folder}/{dataset_type}/images')
    captions_path = os.path.join(
        cur_dir, f'{base_folder}/{dataset_type}/{captions_file}')
    loader = get_data_loader(
        root=images_root,
        captions=captions_path,
        split='test',
        batch_size=batch_size,
        dataset_type=dataset_type)

    for batch_number, (batch_images, batch_captions) in enumerate(loader, start=1):
        for raw_image, caption_text in zip(batch_images, batch_captions):
            # The SHA-1 of the image becomes the id of the caption document;
            # presumably raw_image is a bytes-like object -- confirm against
            # get_data_loader.
            image_digest = hashlib.sha1(raw_image).hexdigest()

            image_doc = Document()
            image_doc.buffer = raw_image
            image_doc.modality = 'image'
            image_doc.mime_type = 'image/jpeg'

            caption_doc = Document(id=image_digest)
            caption_doc.text = caption_text
            caption_doc.modality = 'text'
            caption_doc.mime_type = 'text/plain'
            # The 'id' tag stores the caption text itself, not the digest.
            caption_doc.tags['id'] = caption_text

            yield image_doc
            yield caption_doc

        # Limit is only evaluated between batches.
        if num_docs and batch_number * batch_size >= num_docs:
            break
예제 #2
0
def test_segment_driver():
    """Check the chunks produced when SimpleSegmentDriver runs MockSegmenter.

    A single document with text ``'valid'`` and mime type ``'image/png'`` is
    passed through the driver; its first three chunks are then compared
    against the expected tag id, parent id, blob, weight and mime type.
    """
    parent = Document()
    parent.text = 'valid'
    parent.mime_type = 'image/png'

    segment_driver = SimpleSegmentDriver()
    segmenter = MockSegmenter()
    segment_driver.attach(executor=segmenter, runtime=None)
    segment_driver._apply_all(DocumentSet([parent]))

    # One row per chunk, in chunk order: (tags['id'], blob values, weight, mime type).
    # Only the first chunk is expected to carry 'text/plain'; the other two
    # carry the same mime type that was set on the parent document.
    expected_chunks = (
        (3, [0.0, 0.0, 0.0], 0.0, 'text/plain'),
        (4, [1.0, 1.0, 1.0], 1.0, 'image/png'),
        (5, [2.0, 2.0, 2.0], 2.0, 'image/png'),
    )

    for position, (tag_id, blob_values, weight, mime_type) in enumerate(expected_chunks):
        chunk = parent.chunks[position]
        assert chunk.tags['id'] == tag_id
        assert chunk.parent_id == parent.id
        np.testing.assert_equal(chunk.blob, np.array(blob_values))
        assert chunk.weight == weight
        assert chunk.mime_type == mime_type