def test_set_data(self):
    """Check CacheDataset output across epochs and after ``set_data``.

    Each item ``x`` is expected to come out of the loader as ``[[x * 10 + 1]]``
    (``Lambda`` multiplies by 10 and wraps in an array, then ``RandLambda`` adds 1).
    Two epochs are run on the initial list to check the cached content is not
    modified, then ``set_data`` swaps in a new list and a third epoch must
    reflect the new values.
    """
    data_list1 = list(range(10))
    transform = Compose(
        [
            Lambda(func=lambda x: np.array([x * 10])),
            RandLambda(func=lambda x: x + 1),
        ]
    )
    dataset = CacheDataset(
        data=data_list1,
        transform=transform,
        cache_rate=1.0,
        num_workers=4,
        progress=True,
        # NOTE(review): cache copying is disabled only on linux, where the loader
        # below uses worker processes; presumably those workers keep the parent's
        # cache untouched -- confirm.
        copy_cache=sys.platform != "linux",
    )
    num_workers = 2 if sys.platform == "linux" else 0
    dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)

    def _check_epoch(expected):
        # Run one full epoch; counting batches makes a short or empty loader fail
        # instead of letting the per-item checks pass vacuously.
        seen = 0
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[expected[i] * 10 + 1]], d)
            seen += 1
        self.assertEqual(seen, len(expected))

    _check_epoch(data_list1)
    # simulate another epoch, the cache content should not be modified
    _check_epoch(data_list1)

    # update the datalist and fill the cache content
    data_list2 = list(range(-10, 0))
    dataset.set_data(data=data_list2)
    self.assertEqual(len(dataset), len(data_list2))
    # rerun with updated cache content
    _check_epoch(data_list2)
def test_hash_as_key(self, transform, expected_shape):
    """Check that ``hash_as_key=True`` de-duplicates identical items in the cache.

    Five data dicts are built from ids ``1, 2, 2, 3, 3``, so only three are
    distinct. The dataset length must stay 5 while the cache holds 3 entries,
    and after ``set_data`` with the first three dicts (ids ``1, 2, 2``) the
    cache must hold 2 entries.
    """
    # Cast to float: nibabel >= 5 raises for int64 image data unless a dtype is
    # passed explicitly, and np.random.randint defaults to int64 on most platforms.
    test_image = nib.Nifti1Image(
        np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)
    )
    keys = ["image", "label", "extra"]
    ids = ["1", "2", "2", "3", "3"]
    with tempfile.TemporaryDirectory() as tempdir:
        # Write each distinct file once; the duplicated ids only need to appear
        # in ``test_data``, not be saved to disk twice.
        for i in dict.fromkeys(ids):
            for k in keys:
                nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz"))
        test_data = [
            {k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in keys} for i in ids
        ]

        dataset = CacheDataset(
            data=test_data, transform=transform, cache_num=4, num_workers=2, hash_as_key=True
        )
        self.assertEqual(len(dataset), 5)
        # ensure no duplicated cache content
        self.assertEqual(len(dataset._cache), 3)
        self.assertEqual(dataset.cache_num, 3)
        data1 = dataset[0]
        data2 = dataset[1]
        data3 = dataset[-1]
        # test slice indices
        data4 = dataset[0:-1]
        self.assertEqual(len(data4), 4)

        if transform is None:
            # Without a transform the cached items are the raw path dicts.
            self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz"))
            self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz"))
            self.assertEqual(data3["image"], os.path.join(tempdir, "image3.nii.gz"))
        else:
            self.assertTupleEqual(data1["image"].shape, expected_shape)
            self.assertTupleEqual(data2["label"].shape, expected_shape)
            self.assertTupleEqual(data3["image"].shape, expected_shape)
            for d in data4:
                self.assertTupleEqual(d["image"].shape, expected_shape)

        test_data2 = test_data[:3]
        dataset.set_data(data=test_data2)
        self.assertEqual(len(dataset), 3)
        # ensure no duplicated cache content
        self.assertEqual(len(dataset._cache), 2)
        self.assertEqual(dataset.cache_num, 2)