def test_inverse_array(self, use_compose, dtype, device):
    """Push an array through a chain of spatial transforms, then undo them.

    The forward pass must leave one entry in ``applied_operations`` per
    invertible transform.  Inverting, either in one call through
    ``Compose.inverse`` or transform by transform in reverse order, must
    remove those entries again.
    """
    img: MetaTensor
    pipeline = Compose(
        [
            AddChannel(),
            Orientation("RAS"),
            Flip(1),
            Spacing([1.0, 1.2, 0.9], align_corners=False),
        ]
    )
    invertible = [t for t in pipeline.transforms if isinstance(t, InvertibleTransform)]
    num_invertible = len(invertible)

    # forward: every invertible transform records itself on the output
    img = pipeline(self.get_image(dtype, device))
    self.assertEqual(len(img.applied_operations), num_invertible)

    if use_compose:
        # inverse in a single call through Compose
        img = pipeline.inverse(img)
        self.assertEqual(len(img.applied_operations), 0)
        return

    # inverse one transform at a time, most recently applied first
    _tr: InvertibleTransform
    remaining = num_invertible
    for _tr in reversed(pipeline.transforms):
        if not isinstance(_tr, InvertibleTransform):
            continue
        img = _tr.inverse(img)
        remaining -= 1
        self.assertEqual(len(img.applied_operations), remaining)
def test_inverse_inferred_seg(self, extra_transform):
    """Invert a network prediction back to the original label geometry.

    Labels are padded and then altered by ``extra_transform`` before being fed
    through a small UNet.  The prediction carries the labels'
    ``applied_operations``.  The last recorded operation is undone by hand with
    an untraced ``ResizeWithPadOrCrop``, after which ``Compose.inverse`` must
    restore the original label shape.
    """
    test_data = []
    for _ in range(20):
        img, lbl = create_test_image_2d(100, 101)
        test_data.append({"image": img, "label": lbl.astype(np.float32)})

    batch_size = 10
    # worker processes only on linux; 0 elsewhere (e.g. mac)
    num_workers = 0 if sys.platform != "linux" else 2
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,)).to(device)

    batch = first(loader)
    self.assertEqual(batch["image"].shape[0], batch_size * NUM_SAMPLES)
    labels = batch["label"].to(device)
    self.assertIsInstance(labels, MetaTensor)
    preds = model(labels).detach().cpu()
    decollated = decollate_batch(preds)
    self.assertIsInstance(decollated[0], MetaTensor)

    # work on a single prediction taken from the batch
    seg = first(decollated)
    # switch the recorded interpolation mode of this one output to nearest
    convert_applied_interp_mode(seg.applied_operations, mode="nearest", align_corners=None)

    # pop the last recorded op and resize back to its "orig_size" by hand,
    # without tracing the resize as a new op
    last_op = seg.applied_operations.pop(-1)
    resizer = ResizeWithPadOrCrop(spatial_size=last_op["orig_size"])
    with resizer.trace_transform(False):
        seg = resizer(seg)

    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse({"label": seg})["label"]
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
def test_inverse_inferred_seg(self):
    """Invert a network prediction using its input labels' transform record.

    Labels are padded and centre-cropped, then passed through a small UNet.
    The prediction is paired with the labels' ``label_transforms`` entry so
    that ``Compose.inverse`` can undo the spatial transforms and recover the
    original label shape.
    """
    test_data = []
    for _ in range(20):
        img, lbl = create_test_image_2d(100, 101)
        test_data.append({"image": img, "label": lbl.astype(np.float32)})

    batch_size = 10
    # worker processes everywhere except mac
    num_workers = 0 if sys.platform == "darwin" else 2
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))])
    num_invertible_transforms = len([t for t in transforms.transforms if isinstance(t, InvertibleTransform)])

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(dimensions=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

    batch = first(loader)
    preds = model(batch["label"].to(device)).detach().cpu()

    # attach the labels' transform record to the prediction so it can be inverted
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
    decollated = decollate_batch({"label": preds, label_transform_key: batch[label_transform_key]})

    # invert one decollated segmentation
    seg_dict = first(decollated)
    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse(seg_dict)["label"]

    self.assertEqual(len(batch["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
def test_transforms(self):
    """Mix MetaTensor-aware and plain-tensor transforms, forward and inverse.

    ``ToMetaTensord``/``FromMetaTensord`` switch the ``im`` entry between a
    ``MetaTensor`` and a plain ``torch.Tensor``.  A ``MetaTensor`` keeps its
    applied operations on the tensor.  A plain tensor keeps them under
    ``PostFix.transforms(key)`` next to its meta dict.  After every step the
    number of dict entries, the tensor type and the number of recorded
    operations are checked.  This is done one transform at a time and for the
    whole ``Compose`` at once.
    """
    key = "im"
    _, im = self.get_im()
    tr = Compose([ToMetaTensord(key), BorderPadd(key, 1), DivisiblePadd(key, 16), FromMetaTensord(key)])
    num_tr = len(tr.transforms)

    def _make_input():
        return {key: im, PostFix.meta(key): {"affine": torch.eye(4)}}

    def _n_applied_meta(d, n_entries):
        # `im` is a MetaTensor; return the length of its own op record
        self.assertEqual(len(d), n_entries)
        self.assertIsInstance(d[key], MetaTensor)
        return len(d[key].applied_operations)

    def _n_applied_plain(d):
        # plain tensor: three entries (im, im_meta_dict, im_transforms)
        self.assertEqual(len(d), 3)
        self.assertIsInstance(d[key], torch.Tensor)
        self.assertNotIsInstance(d[key], MetaTensor)
        return len(d[PostFix.transforms(key)])

    data = _make_input()

    # apply one at a time: after step k, k operations are recorded
    for step, xform in enumerate(tr.transforms, start=1):
        data = xform(data)
        if isinstance(xform, (ToMetaTensord, BorderPadd, DivisiblePadd)):
            # only `im`, plus its meta dict for compatibility when config.USE_META_DICT is set
            n_applied = _n_applied_meta(data, 2 if config.USE_META_DICT else 1)
        else:
            n_applied = _n_applied_plain(data)
        self.assertEqual(n_applied, step)

    # inverse one at a time: each inverse removes one recorded operation
    for step, xform in enumerate(reversed(tr.transforms), start=1):
        data = xform.inverse(data)
        if isinstance(xform, (FromMetaTensord, BorderPadd, DivisiblePadd)):
            # NOTE(review): unlike the forward loop this expects exactly one entry
            # regardless of config.USE_META_DICT -- confirm that is intended
            n_applied = _n_applied_meta(data, 1)
        else:
            n_applied = _n_applied_plain(data)
        self.assertEqual(n_applied, num_tr - step)

    # apply all in one go
    data = tr(_make_input())
    self.assertEqual(_n_applied_plain(data), num_tr)

    # inverse all in one go
    data = tr.inverse(data)
    self.assertEqual(_n_applied_plain(data), 0)
def test_invert(self):
    """Randomly augment a label volume, then undo the augmentation with ``Invert``.

    Every decollated item must leave the forward pipeline as a 100x100x100
    volume.  After inversion with nearest interpolation it must be
    integer-valued, have spatial shape (100, 101, 107), keep the source
    filename and agree with the original label file on all but a bounded
    number of voxels.

    The global seed is always reset in ``finally``, so a failing assertion
    does not leak deterministic RNG state into later tests.
    """
    set_determinism(seed=0)
    try:
        im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])  # label image, discrete
        data = [im_fname for _ in range(12)]
        transform = Compose(
            [
                LoadImage(image_only=True),
                EnsureChannelFirst(),
                Orientation("RPS"),
                Spacing(pixdim=(1.2, 1.01, 0.9), mode="bilinear", dtype=np.float32),
                RandFlip(prob=0.5, spatial_axis=[1, 2]),
                RandAxisFlip(prob=0.5),
                RandRotate90(prob=0, spatial_axes=(1, 2)),
                RandZoom(prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
                RandRotate(prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
                RandAffine(prob=0.5, rotate_range=np.pi, mode="nearest"),
                ResizeWithPadOrCrop(100),
                CastToType(dtype=torch.uint8),
            ]
        )
        # no worker processes off linux (e.g. mac) or when cuda is available
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = Dataset(data, transform=transform)
        self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
        inverter = Invert(transform=transform, nearest_interp=True, device="cpu")

        # every entry of `data` is the same file, so load the reference once
        # instead of re-reading it for every item
        original = LoadImage(image_only=True)(data[-1])
        original_arr = original.numpy()
        original_name = original.meta["filename_or_obj"]

        for batch in loader:
            for item in decollate_batch(batch):
                orig = deepcopy(item)
                inverted = inverter(item)
                self.assertTupleEqual(orig.shape[1:], (100, 100, 100))
                # nearest interpolation must keep the values integral
                torch.testing.assert_allclose(inverted.to(torch.uint8).to(torch.float), inverted.to(torch.float))
                self.assertTupleEqual(inverted.shape[1:], (100, 101, 107))

                # check labels match
                reverted = inverted.detach().cpu().numpy().astype(np.int32)
                n_good = np.sum(np.isclose(reverted, original_arr, atol=1e-3))
                self.assertEqual(inverted.meta["filename_or_obj"], original_name)
                n_diff = reverted.size - n_good
                print("invert diff", n_diff)
                self.assertLess(n_diff, 300000, f"diff. {n_diff}")
    finally:
        set_determinism(seed=None)
def test_inverse_inferred_seg(self, extra_transform):
    """Invert network predictions one item at a time and as a whole batch.

    Labels are padded and altered by ``extra_transform``, then passed through a
    small UNet.  The prediction is paired with the labels' ``label_transforms``
    record.  A single decollated item (with its interpolation mode switched to
    nearest) is inverted via ``Compose.inverse``.  The full batch is inverted
    via ``BatchInverseTransform``.  Both must recover the original label shape.
    """
    test_data = []
    for _ in range(20):
        img, lbl = create_test_image_2d(100, 101)
        test_data.append({"image": img, "label": lbl.astype(np.float32)})

    batch_size = 10
    # worker processes only on linux; 0 elsewhere (e.g. mac)
    num_workers = 0 if sys.platform != "linux" else 2
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
    num_invertible_transforms = len([t for t in transforms.transforms if isinstance(t, InvertibleTransform)])

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

    batch = first(loader)
    self.assertEqual(len(batch["label_transforms"]), num_invertible_transforms)
    self.assertEqual(batch["image"].shape[0], batch_size * NUM_SAMPLES)
    preds = model(batch["label"].to(device)).detach().cpu()

    # attach the labels' transform record to the predictions
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX
    segs_dict = {"label": preds, label_transform_key: batch[label_transform_key]}

    # invert a single decollated prediction, switched to nearest interpolation
    seg_dict = first(decollate_batch(segs_dict))
    convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)
    with allow_missing_keys_mode(transforms):
        inv_seg = transforms.inverse(seg_dict)["label"]

    # record lengths are re-checked after inversion, for the batch and for the item
    self.assertEqual(len(batch["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

    # invert the whole batch
    batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
    with allow_missing_keys_mode(transforms):
        inv_batch = batch_inverter(segs_dict)
    self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
def test_invert(self):
    """Randomly augment an image/label pair, then invert copies with ``Invertd``.

    Two copies of each key are made at the end of the pipeline.
    ``inverter`` inverts the first copies using the transforms recorded for
    ``label``, always with nearest interpolation.  ``inverter_1`` inverts the
    second copies using the transforms recorded for ``image``, with nearest
    interpolation for the image copy only.  The results are checked for
    shape, integrality (nearest vs. non-nearest) and agreement with the
    original label file.

    The global seed is always reset in ``finally``, so a failing assertion
    does not leak deterministic RNG state into later tests.
    """
    set_determinism(seed=0)
    try:
        im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose(
            [
                LoadImaged(KEYS, image_only=True),
                EnsureChannelFirstd(KEYS),
                Orientationd(KEYS, "RPS"),
                Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
                ScaleIntensityd("image", minv=1, maxv=10),
                RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
                RandAxisFlipd(KEYS, prob=0.5),
                RandRotate90d(KEYS, prob=0, spatial_axes=(1, 2)),
                RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
                RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
                RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
                ResizeWithPadOrCropd(KEYS, 100),
                CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
                CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
                CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
            ]
        )
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # no worker processes off linux (e.g. mac) or when cuda is available
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = Dataset(data, transform=transform)
        transform.inverse(dataset[0])
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
        inverter = Invertd(
            # invert the first copies using the transforms recorded for `label`
            keys=["image_inverted", "label_inverted"],
            transform=transform,
            orig_keys=["label", "label"],
            nearest_interp=True,
            device="cpu",
        )
        inverter_1 = Invertd(
            # invert the second copies using the transforms recorded for `image`,
            # with nearest interpolation for the image copy only
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            nearest_interp=[True, False],
            device="cpu",
        )
        expected_keys = ["image", "image_inverted", "image_inverted1", "label", "label_inverted", "label_inverted1"]

        # every entry of `data` refers to the same files, so load the reference
        # label once instead of re-reading it for every item
        original = LoadImaged(KEYS, image_only=True)(data[-1])["label"]
        original_name = data[-1]["label"]

        # execute 1 epoch
        for batch in loader:
            for item in decollate_batch(batch):
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))

                # check the nearest interpolation mode
                for name in ("image_inverted", "label_inverted"):
                    inv = item[name]
                    torch.testing.assert_allclose(inv.to(torch.uint8).to(torch.float), inv.to(torch.float))
                    self.assertTupleEqual(inv.shape[1:], (100, 101, 107))

                # different items use different interpolation modes to invert
                img_inv = item["image_inverted1"]
                # nearest interpolation: accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(img_inv.to(torch.float) - img_inv.to(torch.uint8).to(torch.float)).item(), 1.0
                )
                self.assertTupleEqual(img_inv.shape, (1, 100, 101, 107))
                lbl_inv = item["label_inverted1"]
                # non-nearest interpolation: accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(lbl_inv.to(torch.float) - lbl_inv.to(torch.uint8).to(torch.float)).item(), 10000.0
                )
                self.assertTupleEqual(lbl_inv.shape, (1, 100, 101, 107))

                # check labels match
                reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
                n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
                self.assertEqual(item["label_inverted"].meta["filename_or_obj"], original_name)
                n_diff = reverted.size - n_good
                print("invert diff", n_diff)
                # 25300: 2 workers (cpu, non-macos)
                # 1812: 0 workers (gpu or macos)
                # 1821: windows torch 1.10.0
                self.assertLess(n_diff, 40000, f"diff. {n_diff}")
    finally:
        set_determinism(seed=None)