コード例 #1
0
    def test_dataset_2d(self):
        """Check that a SegmentationDataset opened with ``ndim=2`` yields correctly shaped patches.

        Default test data is written to ``self.path``; the dataset must expose the full
        raw/label volumes with the written shape, report ``_ndim == 2``, and return raw
        and label samples whose shape equals the requested patch shape.
        """
        from torch_em.data import SegmentationDataset

        raw_key = "raw"
        label_key = "labels"
        data_shape, _ = self.create_default_data(raw_key, label_key)

        patch_shape = (1, 32, 32)
        dataset = SegmentationDataset(
            self.path, raw_key, self.path, label_key, patch_shape=patch_shape, ndim=2
        )

        # The full volumes should keep the shape of the data that was written.
        self.assertEqual(dataset.raw.shape, data_shape)
        self.assertEqual(dataset.labels.shape, data_shape)
        self.assertEqual(dataset._ndim, 2)

        # Every sample (raw and labels alike) is expected to match the patch shape exactly.
        for index in range(10):
            raw_patch, label_patch = dataset[index]
            self.assertEqual(raw_patch.shape, patch_shape)
            self.assertEqual(label_patch.shape, patch_shape)
コード例 #2
0
    def test_wrap_dataset(self):
        """Check that DatasetWrapper applies its transform to every sample.

        The transform crops raw and labels to different sizes (10^3 vs. 20^3 on the
        trailing three axes), so the test can tell that each element of the sample
        pair is transformed independently.

        NOTE(review): relies on ``self.raw_key`` / ``self.label_key`` and data at
        ``self.path`` being provided by the test class (setup not visible here).
        """
        # Import both classes locally, matching the sibling tests, instead of relying
        # on a module-level ``SegmentationDataset`` import that may not exist.
        from torch_em.data import DatasetWrapper, SegmentationDataset

        patch_shape = (32, 32, 32)
        ds = SegmentationDataset(
            self.path,
            self.raw_key,
            self.path,
            self.label_key,
            patch_shape=patch_shape,
        )

        def crop_sample(xy):
            # Crop only the trailing three axes; the leading (presumably channel)
            # axis is kept whole. Raw and labels get different crop sizes on purpose.
            return xy[0][:, :10, :10, :10], xy[1][:, :20, :20, :20]

        wrapped_ds = DatasetWrapper(ds, crop_sample)

        expected_shape_x = (1, 10, 10, 10)
        expected_shape_y = (1, 20, 20, 20)
        for i in range(3):
            x, y = wrapped_ds[i]
            self.assertEqual(x.shape, expected_shape_x)
            self.assertEqual(y.shape, expected_shape_y)
コード例 #3
0
    def test_roi(self):
        """Check that a ``roi`` restricts the dataset to a sub-volume.

        The region ``32:96`` on each of three axes (a 64^3 cube) is selected from the
        default test data. The dataset's raw and label volumes must have the ROI shape,
        and each sample must have the patch shape with a leading axis of size 1.
        """
        from torch_em.data import SegmentationDataset

        raw_key, label_key = "raw", "labels"
        self.create_default_data(raw_key, label_key)

        patch_shape = (32, 32, 32)
        # Same value as np.s_[32:96, 32:96, 32:96]: a tuple of three slice(32, 96) objects.
        region = (slice(32, 96),) * 3
        dataset = SegmentationDataset(
            self.path, raw_key, self.path, label_key, patch_shape=patch_shape, roi=region
        )

        region_shape = (64, 64, 64)
        self.assertEqual(dataset.raw.shape, region_shape)
        self.assertEqual(dataset.labels.shape, region_shape)

        sample_shape = (1, *patch_shape)
        for index in range(10):
            raw_patch, label_patch = dataset[index]
            self.assertEqual(raw_patch.shape, sample_shape)
            self.assertEqual(label_patch.shape, sample_shape)
コード例 #4
0
    def test_dataset_3d4d(self):
        """Check 3D sampling (``ndim=3``) from 4D data.

        A ``(4, 128, 128, 128)`` volume is written with ``(1, 32, 32, 32)`` chunks and
        opened with a patch shape that has a singleton leading axis. The full volumes
        must keep the written shape, and every sample must equal the patch shape.
        """
        from torch_em.data import SegmentationDataset

        raw_key, label_key = "raw", "labels"
        data_shape = (4,) + (128,) * 3
        chunk_shape = (1,) + (32,) * 3
        create_segmentation_test_data(
            self.path, raw_key, label_key, shape=data_shape, chunks=chunk_shape
        )

        patch_shape = (1,) + (32,) * 3
        dataset = SegmentationDataset(
            self.path, raw_key, self.path, label_key, ndim=3, patch_shape=patch_shape
        )

        self.assertEqual(dataset.raw.shape, data_shape)
        self.assertEqual(dataset.labels.shape, data_shape)
        self.assertEqual(dataset._ndim, 3)

        for index in range(10):
            raw_patch, label_patch = dataset[index]
            self.assertEqual(raw_patch.shape, patch_shape)
            self.assertEqual(label_patch.shape, patch_shape)
コード例 #5
0
    def test_with_label_channels(self):
        """Check ``with_label_channels=True`` for 3-channel labels and channel-less raw data.

        Labels are stored as ``(3, 128, 128, 128)`` and raw as ``(128, 128, 128)``.
        Samples are expected to come back as ``(1, *patch)`` for raw and
        ``(3, *patch)`` for labels, and the dataset must report ``_ndim == 3``.
        """
        from torch_em.data import SegmentationDataset

        raw_key, label_key = "raw", "labels"
        n_label_channels = 3
        spatial_shape = (128, 128, 128)
        spatial_chunks = (32, 32, 32)
        label_shape = (n_label_channels, *spatial_shape)
        label_chunks = (1, *spatial_chunks)

        # Random float data is sufficient here: only shapes are checked.
        with h5py.File(self.path, "a") as h5_file:
            h5_file.create_dataset(label_key, data=np.random.rand(*label_shape), chunks=label_chunks)
            h5_file.create_dataset(raw_key, data=np.random.rand(*spatial_shape), chunks=spatial_chunks)

        patch_shape = (32, 32, 32)
        dataset = SegmentationDataset(
            self.path, raw_key, self.path, label_key, patch_shape=patch_shape, with_label_channels=True
        )
        self.assertEqual(dataset._ndim, 3)

        raw_sample_shape = (1, *patch_shape)
        label_sample_shape = (n_label_channels, *patch_shape)
        for index in range(10):
            raw_patch, label_patch = dataset[index]
            self.assertEqual(raw_patch.shape, raw_sample_shape)
            self.assertEqual(label_patch.shape, label_sample_shape)