def test_dataset_2d(self):
    """Check that ``ndim=2`` sampling returns patches of exactly ``patch_shape``.

    The dataset must expose raw and label arrays with the shape reported by
    ``create_default_data``. Every raw/label item drawn from it must have the
    shape ``(1, 32, 32)``, leading singleton axis included.
    """
    from torch_em.data import SegmentationDataset

    keys = ("raw", "labels")
    volume_shape, _ = self.create_default_data(*keys)
    patch_shape = (1, 32, 32)
    dataset = SegmentationDataset(
        self.path, keys[0], self.path, keys[1],
        patch_shape=patch_shape, ndim=2,
    )

    self.assertEqual(dataset.raw.shape, volume_shape)
    self.assertEqual(dataset.labels.shape, volume_shape)
    self.assertEqual(dataset._ndim, 2)

    # The returned patch shape is expected to equal patch_shape unchanged.
    for index in range(10):
        raw_patch, label_patch = dataset[index]
        self.assertEqual(raw_patch.shape, patch_shape)
        self.assertEqual(label_patch.shape, patch_shape)
def test_wrap_dataset(self):
    """Check that ``DatasetWrapper`` applies its transform to every item.

    The wrapped dataset yields ``(x, y)`` pairs of 32^3 patches. The wrapper's
    transform crops ``x`` to 10^3 and ``y`` to 20^3 (keeping the first axis).
    The items returned by the wrapper must therefore have the cropped shapes.
    """
    # Import SegmentationDataset explicitly, matching the sibling tests;
    # it was previously used without a local import.
    from torch_em.data import DatasetWrapper, SegmentationDataset

    raw_key, label_key = "raw", "labels"
    # Create the test data first, as the other tests do. Without this there
    # is nothing at self.path for the dataset to read.
    self.create_default_data(raw_key, label_key)

    patch_shape = (32, 32, 32)
    ds = SegmentationDataset(self.path, raw_key, self.path, label_key, patch_shape=patch_shape)
    wrapped_ds = DatasetWrapper(
        ds, lambda xy: (xy[0][:, :10, :10, :10], xy[1][:, :20, :20, :20])
    )

    expected_shape_x = (1, 10, 10, 10)
    expected_shape_y = (1, 20, 20, 20)
    for i in range(3):
        x, y = wrapped_ds[i]
        self.assertEqual(x.shape, expected_shape_x)
        self.assertEqual(y.shape, expected_shape_y)
def test_roi(self):
    """Check that ``roi`` restricts the dataset to a sub-volume.

    Slicing ``32:96`` on each of the three axes gives a 64^3 view for both raw
    and labels. Each item drawn from it has shape ``(1, 32, 32, 32)``: the
    patch shape with a leading axis of size 1.
    """
    from torch_em.data import SegmentationDataset

    raw_key = "raw"
    label_key = "labels"
    self.create_default_data(raw_key, label_key)

    patch_shape = (32,) * 3
    roi = np.s_[32:96, 32:96, 32:96]
    dataset = SegmentationDataset(
        self.path, raw_key, self.path, label_key,
        patch_shape=patch_shape, roi=roi,
    )

    cropped_shape = (64, 64, 64)
    self.assertEqual(dataset.raw.shape, cropped_shape)
    self.assertEqual(dataset.labels.shape, cropped_shape)

    target = (1, *patch_shape)
    for index in range(10):
        raw_patch, label_patch = dataset[index]
        self.assertEqual(raw_patch.shape, target)
        self.assertEqual(label_patch.shape, target)
def test_dataset_3d4d(self):
    """Check ``ndim=3`` sampling on 4-axis data of shape ``(4, 128, 128, 128)``.

    The dataset must report the full 4-axis shape for raw and labels. Items
    drawn from it must match ``patch_shape`` ``(1, 32, 32, 32)`` exactly.
    """
    from torch_em.data import SegmentationDataset

    raw_key, label_key = "raw", "labels"
    data_shape = (4, 128, 128, 128)
    chunk_shape = (1, 32, 32, 32)
    create_segmentation_test_data(
        self.path, raw_key, label_key, shape=data_shape, chunks=chunk_shape
    )

    patch_shape = (1, 32, 32, 32)
    dataset = SegmentationDataset(
        self.path, raw_key, self.path, label_key,
        ndim=3, patch_shape=patch_shape,
    )

    self.assertEqual(dataset.raw.shape, data_shape)
    self.assertEqual(dataset.labels.shape, data_shape)
    self.assertEqual(dataset._ndim, 3)

    for index in range(10):
        raw_patch, label_patch = dataset[index]
        self.assertEqual(raw_patch.shape, patch_shape)
        self.assertEqual(label_patch.shape, patch_shape)
def test_with_label_channels(self):
    """Check ``with_label_channels=True`` with a 3-axis raw and 4-axis labels.

    Labels are written with shape ``(3, 128, 128, 128)`` and raw with shape
    ``(128, 128, 128)``. Raw items must come back as ``(1, 32, 32, 32)``.
    Label items keep their leading axis of size 3: ``(3, 32, 32, 32)``.
    """
    from torch_em.data import SegmentationDataset

    raw_key, label_key = "raw", "labels"
    n_label_channels = 3
    spatial_shape = (128, 128, 128)
    label_shape = (n_label_channels,) + spatial_shape
    label_chunks = (1, 32, 32, 32)

    # Labels are written before raw, so np.random.rand is called in the same order as before.
    with h5py.File(self.path, "a") as f:
        f.create_dataset(label_key, data=np.random.rand(*label_shape), chunks=label_chunks)
        f.create_dataset(raw_key, data=np.random.rand(*spatial_shape), chunks=label_chunks[1:])

    patch_shape = (32, 32, 32)
    dataset = SegmentationDataset(
        self.path, raw_key, self.path, label_key,
        patch_shape=patch_shape, with_label_channels=True,
    )
    self.assertEqual(dataset._ndim, 3)

    raw_target = (1, *patch_shape)
    label_target = (n_label_channels, *patch_shape)
    for index in range(10):
        raw_patch, label_patch = dataset[index]
        self.assertEqual(raw_patch.shape, raw_target)
        self.assertEqual(label_patch.shape, label_target)