def test_raw_list_text_dset_loader(source_list, target_sentences, num_samples_per_bucket, expected_source_0, expected_target_0, expected_label_0):
    """RawListTextDatasetLoader constructs correctly and loads a ParallelDataSet
    whose first bucket matches the expected source/target/label contents."""
    buckets = sockeye.data_io.define_parallel_buckets(4, 4, 1, 1.0)
    loader = data_io.RawListTextDatasetLoader(buckets=buckets, eos_id=10, pad_id=C.PAD_ID)

    # Construction sanity checks: right type, expected number of buckets.
    assert isinstance(loader, data_io.RawListTextDatasetLoader)
    assert len(loader.buckets) == 3

    # Populate the dataset and verify per-bucket structure and contents.
    dataset = loader.load(source_list, target_sentences, num_samples_per_bucket)
    assert isinstance(dataset, sockeye.data_io.ParallelDataSet)
    assert len(dataset.source) == 3
    assert len(dataset.target) == 3
    assert len(dataset.label) == 3
    np.testing.assert_equal(dataset.source[0], expected_source_0)
    np.testing.assert_almost_equal(dataset.target[0].asnumpy(), expected_target_0)
    np.testing.assert_almost_equal(dataset.label[0].asnumpy(), expected_label_0)
def test_image_text_sample_iter(source_list, target_sentences, num_samples_per_bucket):
    """ImageTextSampleIter yields DataBatch objects with the expected image/feature
    shapes, both when loading raw images and when preloading .npy features."""
    batch_size = 2
    image_size = _CNN_INPUT_IMAGE_SHAPE
    buckets = sockeye.data_io.define_parallel_buckets(4, 4, 1, 1.0)
    bucket_batch_sizes = sockeye.data_io.define_bucket_batch_sizes(
        buckets,
        batch_size,
        batch_by_words=False,
        batch_num_devices=1,
        data_target_average_len=[None] * len(buckets))
    loader = data_io.RawListTextDatasetLoader(buckets=buckets, eos_id=-1, pad_id=C.PAD_ID)

    with TemporaryDirectory() as work_dir:
        # Paths for the random images/features we are about to generate.
        source_list_img = [os.path.join(work_dir, s + ".jpg") for s in source_list]
        source_list_npy = [os.path.join(work_dir, s + ".npy") for s in source_list]
        for path in source_list_img:
            generate_img_or_feat(os.path.join(work_dir, path), use_features=False)
        for path in source_list_npy:
            generate_img_or_feat(os.path.join(work_dir, path), use_features=True)

        # Raw-image iterator: batches carry tensors shaped like the CNN input.
        dataset = loader.load(source_list_img, target_sentences, num_samples_per_bucket)
        data_iter = data_io.ImageTextSampleIter(dataset, buckets, batch_size, bucket_batch_sizes,
                                                image_size,
                                                use_feature_loader=False,
                                                preload_features=False)
        batch = data_iter.next()
        assert isinstance(batch, mx.io.DataBatch)
        np.testing.assert_equal(batch.data[0].asnumpy().shape[1:], image_size)

        # Feature-file iterator with all features preloaded into memory.
        dataset = loader.load(source_list_npy, target_sentences, num_samples_per_bucket)
        data_iter = data_io.ImageTextSampleIter(dataset, buckets, batch_size, bucket_batch_sizes,
                                                _FEATURE_SHAPE,
                                                use_feature_loader=True,
                                                preload_features=True)
        batch = data_iter.next()
        assert isinstance(batch, mx.io.DataBatch)
        np.testing.assert_equal(batch.data[0].asnumpy().shape[1:], _FEATURE_SHAPE)