예제 #1
0
def test_raw_list_text_dset_loader(source_list, target_sentences,
                                   num_samples_per_bucket, expected_source_0,
                                   expected_target_0, expected_label_0):
    """Verify that RawListTextDatasetLoader builds a ParallelDataSet whose
    first bucket matches the expected source, target, and label arrays.

    All inputs are pytest fixtures: the raw source list, tokenized target
    sentences, per-bucket sample counts, and the expected contents of
    bucket 0.
    """
    # Construct the loader over the standard (4, 4) parallel buckets.
    buckets = data_io.define_parallel_buckets(4, 4, 1, 1.0)
    dset_loader = data_io.RawListTextDatasetLoader(buckets=buckets,
                                                   eos_id=10,
                                                   pad_id=C.PAD_ID)

    assert isinstance(dset_loader, data_io.RawListTextDatasetLoader)
    assert len(dset_loader.buckets) == 3

    # Load the parallel data and check that each of source/target/label
    # has one entry per bucket.
    pop_dset_loader = dset_loader.load(source_list, target_sentences,
                                       num_samples_per_bucket)

    assert isinstance(pop_dset_loader, data_io.ParallelDataSet)
    assert len(pop_dset_loader.source) == 3
    assert len(pop_dset_loader.target) == 3
    assert len(pop_dset_loader.label) == 3
    # Source entries are raw (e.g. path strings), so compare directly;
    # target/label are NDArrays and must be converted via asnumpy().
    np.testing.assert_equal(pop_dset_loader.source[0], expected_source_0)
    np.testing.assert_almost_equal(pop_dset_loader.target[0].asnumpy(),
                                   expected_target_0)
    np.testing.assert_almost_equal(pop_dset_loader.label[0].asnumpy(),
                                   expected_label_0)
예제 #2
0
def test_image_text_sample_iter(source_list, target_sentences,
                                num_samples_per_bucket):
    """Exercise ImageTextSampleIter in both loading modes.

    Generates random images (.jpg) and pre-extracted features (.npy) in a
    temporary directory, then checks that the iterator yields MXNet
    DataBatch objects whose per-sample data shape matches the expected
    image/feature shape, both with the on-the-fly image loader and with
    the feature loader preloading everything into memory.
    """
    batch_size = 2
    image_size = _CNN_INPUT_IMAGE_SHAPE
    buckets = data_io.define_parallel_buckets(4, 4, 1, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    dset_loader = data_io.RawListTextDatasetLoader(buckets=buckets,
                                                   eos_id=-1,
                                                   pad_id=C.PAD_ID)
    with TemporaryDirectory() as work_dir:
        # Absolute paths for the image and feature files of each sample.
        source_list_img = [os.path.join(work_dir, s + ".jpg")
                           for s in source_list]
        source_list_npy = [os.path.join(work_dir, s + ".npy")
                           for s in source_list]
        # Create random images/features. The paths already include
        # work_dir, so no further joining is needed.
        for filename in source_list_img:
            generate_img_or_feat(filename, use_features=False)
        for filename in source_list_npy:
            generate_img_or_feat(filename, use_features=True)

        # Test image iterator (loads and resizes images on the fly).
        pop_dset_loader = dset_loader.load(source_list_img, target_sentences,
                                           num_samples_per_bucket)
        data_iter = data_io.ImageTextSampleIter(pop_dset_loader,
                                                buckets,
                                                batch_size,
                                                bucket_batch_sizes,
                                                image_size,
                                                use_feature_loader=False,
                                                preload_features=False)
        data = data_iter.next()
        assert isinstance(data, mx.io.DataBatch)
        # shape[0] is the batch dimension; the rest must match the image shape.
        np.testing.assert_equal(data.data[0].asnumpy().shape[1:], image_size)

        # Test iterator feature loader + preload all to memory.
        pop_dset_loader = dset_loader.load(source_list_npy, target_sentences,
                                           num_samples_per_bucket)
        data_iter = data_io.ImageTextSampleIter(pop_dset_loader,
                                                buckets,
                                                batch_size,
                                                bucket_batch_sizes,
                                                _FEATURE_SHAPE,
                                                use_feature_loader=True,
                                                preload_features=True)
        data = data_iter.next()
        assert isinstance(data, mx.io.DataBatch)
        np.testing.assert_equal(data.data[0].asnumpy().shape[1:],
                                _FEATURE_SHAPE)