Example #1
import sys
import unittest

import numpy as np
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, Lambda, RandLambda


class TestCacheDataset(unittest.TestCase):

    def test_set_data(self):
        data_list1 = list(range(10))

        transform = Compose([
            # deterministic transform: its output is computed once and cached
            Lambda(func=lambda x: np.array([x * 10])),
            # random transform: marks the caching boundary, applied per fetch
            RandLambda(func=lambda x: x + 1)
        ])

        dataset = CacheDataset(
            data=data_list1,
            transform=transform,
            cache_rate=1.0,
            num_workers=4,
            progress=True,
            # deep-copy cached items before the random transforms, except on Linux
            copy_cache=sys.platform != "linux",
        )

        # use DataLoader worker processes only on Linux
        num_workers = 2 if sys.platform == "linux" else 0
        dataloader = DataLoader(dataset=dataset,
                                num_workers=num_workers,
                                batch_size=1)
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d)
        # simulate another epoch, the cache content should not be modified
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d)

        # update the data list and refill the cache with the new content
        data_list2 = list(range(-10, 0))
        dataset.set_data(data=data_list2)
        # rerun with updated cache content
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[data_list2[i] * 10 + 1]], d)
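
For reference, the same set_data flow can be exercised outside a test harness. The sketch below is a minimal standalone version of the scenario above, assuming MONAI and PyTorch are installed; the transforms and data mirror the test, while the progress and copy_cache flags are left at their defaults.

import numpy as np
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, Lambda, RandLambda

transform = Compose([
    Lambda(func=lambda x: np.array([x * 10])),  # cached (deterministic)
    RandLambda(func=lambda x: x + 1),           # applied at load time (random)
])

dataset = CacheDataset(data=list(range(10)), transform=transform, cache_rate=1.0)
loader = DataLoader(dataset=dataset, batch_size=1, num_workers=0)
for i, batch in enumerate(loader):
    np.testing.assert_allclose(batch, [[i * 10 + 1]])

# replace the data list in place; the cache is recomputed for the new items
dataset.set_data(data=list(range(-10, 0)))
for i, batch in enumerate(loader):
    np.testing.assert_allclose(batch, [[(i - 10) * 10 + 1]])

Because only the transforms before the first random transform are cached, set_data triggers recomputation of the deterministic prefix for the new items while the random suffix keeps running per fetch.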
Example #2
import os
import tempfile
import unittest

import nibabel as nib
import numpy as np
from monai.data import CacheDataset


class TestCacheDatasetHashAsKey(unittest.TestCase):

    # driven with (transform, expected_shape) pairs, e.g. via parameterized.expand
    def test_hash_as_key(self, transform, expected_shape):
        test_image = nib.Nifti1Image(
            np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:
            test_data = []
            # indices "2" and "3" are intentionally duplicated to exercise dedup
            for i in ["1", "2", "2", "3", "3"]:
                for k in ["image", "label", "extra"]:
                    nib.save(test_image, os.path.join(tempdir,
                                                      f"{k}{i}.nii.gz"))
                test_data.append({
                    k: os.path.join(tempdir, f"{k}{i}.nii.gz")
                    for k in ["image", "label", "extra"]
                })

            dataset = CacheDataset(data=test_data,
                                   transform=transform,
                                   cache_num=4,
                                   num_workers=2,
                                   hash_as_key=True)
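            # 5 data items but only 3 unique ones ("1", "2", "3"), so the cache
            # holds 3 entries and cache_num is capped from 4 down to 3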
            self.assertEqual(len(dataset), 5)
            # ensure no duplicated cache content
            self.assertEqual(len(dataset._cache), 3)
            self.assertEqual(dataset.cache_num, 3)
            data1 = dataset[0]
            data2 = dataset[1]
            data3 = dataset[-1]
            # test slice indices
            data4 = dataset[0:-1]
            self.assertEqual(len(data4), 4)

            if transform is None:
                self.assertEqual(data1["image"],
                                 os.path.join(tempdir, "image1.nii.gz"))
                self.assertEqual(data2["label"],
                                 os.path.join(tempdir, "label2.nii.gz"))
                self.assertEqual(data3["image"],
                                 os.path.join(tempdir, "image3.nii.gz"))
            else:
                self.assertTupleEqual(data1["image"].shape, expected_shape)
                self.assertTupleEqual(data2["label"].shape, expected_shape)
                self.assertTupleEqual(data3["image"].shape, expected_shape)
                for d in data4:
                    self.assertTupleEqual(d["image"].shape, expected_shape)

            test_data2 = test_data[:3]
            dataset.set_data(data=test_data2)
            self.assertEqual(len(dataset), 3)
            # ensure no duplicated cache content
            self.assertEqual(len(dataset._cache), 2)
            self.assertEqual(dataset.cache_num, 2)
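
The deduplication behavior does not depend on files on disk: with hash_as_key=True, each data item is hashed (by default via pickling), so any picklable items work. The sketch below is a hypothetical in-memory variant; the plain-dict items are illustrative, not from the original suite, while the length and cache_num behavior matches the assertions above.

from monai.data import CacheDataset

items = [{"id": 1}, {"id": 2}, {"id": 2}, {"id": 3}]  # one duplicate entry
dataset = CacheDataset(data=items, transform=None, cache_num=4, hash_as_key=True)

print(len(dataset))       # 4 -> length still follows the data list
print(dataset.cache_num)  # 3 -> capped at the number of unique keys

dataset.set_data(data=items[:3])  # unique keys shrink to {1, 2}
print(len(dataset), dataset.cache_num)  # 3 2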