Exemplo n.º 1
0
    def test_shape(self, transform, expected_shape, kwargs=None):
        """Check items read from two ``LMDBDataset`` objects sharing one ``cache_dir``.

        Six NIfTI files holding the same random volume are written to a temporary
        directory and listed as two samples with ``image``/``label``/``extra`` keys.
        A first dataset is built and indexed, then a second one is built on the
        same ``cache_dir`` and indexed as well (presumably the second one is served
        from the cache filled by the first -- confirm against ``LMDBDataset``).

        Args:
            transform: transform handed to ``LMDBDataset``; when ``None`` the items
                are compared against the input file paths instead of array shapes.
            expected_shape: shape expected for every entry of every item when a
                transform is given.
            kwargs: optional extra keyword arguments forwarded to ``LMDBDataset``.
        """
        extra_args = kwargs if kwargs else {}
        keys = ("image", "label", "extra")
        volume = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:

            def in_tmp(filename):
                return os.path.join(tempdir, filename)

            samples = [
                {
                    "image": in_tmp("test_image1.nii.gz"),
                    "label": in_tmp("test_label1.nii.gz"),
                    "extra": in_tmp("test_extra1.nii.gz"),
                },
                {
                    "image": in_tmp("test_image2.nii.gz"),
                    "label": in_tmp("test_label2.nii.gz"),
                    "extra": in_tmp("test_extra2.nii.gz"),
                },
            ]
            # every file stores the same volume; files are written sample by sample, key by key
            for sample in samples:
                for key in keys:
                    nib.save(volume, sample[key])

            cache_dir = os.path.join(tempdir, "cache", "data")

            def make_dataset():
                return LMDBDataset(
                    data=samples, transform=transform, progress=False, cache_dir=cache_dir, **extra_args
                )

            precached = make_dataset()
            first_pre = precached[0]
            second_pre = precached[1]

            postcached = make_dataset()
            first_post = postcached[0]
            second_post = postcached[1]

        # the items were fetched above; from here on only their contents are inspected
        if transform is not None:
            for item in (first_pre, second_pre, first_post, second_post):
                for key in keys:
                    self.assertTupleEqual(item[key].shape, expected_shape)
            return

        # without a transform the items are expected to hold the original path strings
        self.assertEqual(first_pre["image"], in_tmp("test_image1.nii.gz"))
        self.assertEqual(second_pre["label"], in_tmp("test_label2.nii.gz"))
        self.assertEqual(first_post["image"], in_tmp("test_image1.nii.gz"))
        self.assertEqual(second_post["extra"], in_tmp("test_extra2.nii.gz"))
    def test_cache(self):
        """testing no inplace change to the hashed item

        Once without ``hash_func`` and once with ``hash_func=json_hashing``, two
        ``LMDBDataset`` objects are created on the same temporary ``cache_dir``
        using an ``_InplaceXform`` transform. The test checks that ``items`` still
        equals its initial value after each construction, that both datasets
        yield equal contents, and that ``info()`` returns a ``dict``.
        """
        items = [[list(range(i))] for i in range(5)]
        untouched = [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]

        with tempfile.TemporaryDirectory() as tempdir:

            def build(**extra):
                # a fresh transform and a fresh lmdb_kwargs dict for every dataset
                return LMDBDataset(
                    items,
                    transform=_InplaceXform(),
                    cache_dir=tempdir,
                    lmdb_kwargs={"map_size": 10 * 1024},
                    **extra,
                )

            for extra in ({}, {"hash_func": json_hashing}):
                ds = build(**extra)
                self.assertEqual(items, untouched)
                ds1 = build(**extra)
                self.assertEqual(list(ds1), list(ds))
                self.assertEqual(items, untouched)

            # Query info() while the cache directory still exists, so the check
            # does not depend on the dataset remaining usable after its backing
            # directory has been deleted by the context manager.
            self.assertIsInstance(ds1.info(), dict)
    def test_mp_cache(self):
        """Check that caching into ``self.tempdir`` leaves the input items unchanged.

        Once without ``hash_func`` and once with ``hash_func=json_hashing``, two
        ``LMDBDataset`` objects are created on ``self.tempdir`` using an
        ``_InplaceXform`` transform. ``items`` must still equal its initial value
        after each construction, both datasets must yield equal contents, and
        ``info()`` of the last dataset must be a ``dict``.
        """
        items = [[list(range(i))] for i in range(5)]
        untouched = [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]

        def build(**extra):
            # a fresh transform and a fresh lmdb_kwargs dict for every dataset
            return LMDBDataset(
                items,
                transform=_InplaceXform(),
                cache_dir=self.tempdir,
                lmdb_kwargs={"map_size": 10 * 1024},
                **extra,
            )

        for extra in ({}, {"hash_func": json_hashing}):
            ds = build(**extra)
            self.assertEqual(items, untouched)
            ds1 = build(**extra)
            self.assertEqual(list(ds1), list(ds))
            self.assertEqual(items, untouched)

        self.assertTrue(isinstance(ds1.info(), dict))
    def test_shape(self, transform, expected_shape, kwargs=None):
        """Check ``LMDBDataset`` items on a shared ``cache_dir`` and after ``set_data``.

        Six NIfTI files holding the same random volume are written to a temporary
        directory and listed as two samples with ``image``/``label``/``extra`` keys.
        Two datasets are built one after the other on the same ``cache_dir`` and
        their items are checked. Finally the second dataset is given a new list of
        file names (none of which are written to disk) through ``set_data``.

        Args:
            transform: transform handed to ``LMDBDataset``; when ``None`` the items
                are compared against file paths, otherwise against
                ``expected_shape`` and ``set_data`` is expected to raise
                ``RuntimeError``.
            expected_shape: shape expected for every entry of every item when a
                transform is given.
            kwargs: optional extra keyword arguments forwarded to ``LMDBDataset``.
        """
        extra_args = kwargs if kwargs else {}
        keys = ("image", "label", "extra")
        volume = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:

            def in_tmp(filename):
                return os.path.join(tempdir, filename)

            samples = [
                {
                    "image": in_tmp("test_image1.nii.gz"),
                    "label": in_tmp("test_label1.nii.gz"),
                    "extra": in_tmp("test_extra1.nii.gz"),
                },
                {
                    "image": in_tmp("test_image2.nii.gz"),
                    "label": in_tmp("test_label2.nii.gz"),
                    "extra": in_tmp("test_extra2.nii.gz"),
                },
            ]
            # every file stores the same volume; files are written sample by sample, key by key
            for sample in samples:
                for key in keys:
                    nib.save(volume, sample[key])

            cache_dir = os.path.join(tempdir, "cache", "data")

            def make_dataset():
                return LMDBDataset(
                    data=samples, transform=transform, progress=False, cache_dir=cache_dir, **extra_args
                )

            precached = make_dataset()
            first_pre = precached[0]
            second_pre = precached[1]

            postcached = make_dataset()
            first_post = postcached[0]
            second_post = postcached[1]

            if transform is None:
                # without a transform the items are expected to hold the original path strings
                self.assertEqual(first_pre["image"], in_tmp("test_image1.nii.gz"))
                self.assertEqual(second_pre["label"], in_tmp("test_label2.nii.gz"))
                self.assertEqual(first_post["image"], in_tmp("test_image1.nii.gz"))
                self.assertEqual(second_post["extra"], in_tmp("test_extra2.nii.gz"))
            else:
                for item in (first_pre, second_pre, first_post, second_post):
                    for key in keys:
                        self.assertTupleEqual(item[key].shape, expected_shape)

            # update the data to cache
            replacement = [
                {
                    "image": in_tmp("test_image1_new.nii.gz"),
                    "label": in_tmp("test_label1_new.nii.gz"),
                    "extra": in_tmp("test_extra1_new.nii.gz"),
                },
                {
                    "image": in_tmp("test_image2_new.nii.gz"),
                    "label": in_tmp("test_label2_new.nii.gz"),
                    "extra": in_tmp("test_extra2_new.nii.gz"),
                },
            ]
            # test new exchanged cache content
            if transform is not None:
                # filename list updated, files do not exist
                with self.assertRaises(RuntimeError):
                    postcached.set_data(data=replacement)
            else:
                postcached.set_data(data=replacement)
                self.assertEqual(postcached[0]["image"], in_tmp("test_image1_new.nii.gz"))
                self.assertEqual(postcached[0]["label"], in_tmp("test_label1_new.nii.gz"))
                self.assertEqual(postcached[1]["extra"], in_tmp("test_extra2_new.nii.gz"))