def test_correct(self):
    with tempfile.TemporaryDirectory() as temp_dir:
        transforms = Compose(
            [
                LoadImaged(("im1", "im2")),
                EnsureChannelFirstd(("im1", "im2")),
                CopyItemsd(("im2", "im2_meta_dict"), names=("im3", "im3_meta_dict")),
                ResampleToMatchd("im3", "im1_meta_dict"),
                Lambda(update_fname),
                SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False),
            ]
        )
        data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
        # check that output sizes match
        assert_allclose(data["im1"].shape, data["im3"].shape)
        # and that the meta data has been updated accordingly
        assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False)
        assert_allclose(data["im3_meta_dict"]["affine"], data["im1_meta_dict"]["affine"])
        # check we're different from the original
        self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
        self.assertTrue(
            any(
                i != j
                for i, j in zip(data["im3_meta_dict"]["affine"].flatten(), data["im2_meta_dict"]["affine"].flatten())
            )
        )
        # test the inverse
        data = Invertd("im3", transforms, "im3")(data)
        assert_allclose(data["im2"].shape, data["im3"].shape)
def test_correct(self):
    transforms = Compose(
        [
            LoadImaged(("im1", "im2")),
            EnsureChannelFirstd(("im1", "im2")),
            CopyItemsd("im2", names="im3"),
            ResampleToMatchd("im3", "im1"),
            Lambda(update_fname),
            SaveImaged("im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, resample=False),
        ]
    )
    data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
    # check that output sizes match
    assert_allclose(data["im1"].shape, data["im3"].shape)
    # and that the meta data has been updated accordingly
    assert_allclose(data["im3"].affine, data["im1"].affine)
    # check we're different from the original
    self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
    self.assertTrue(any(i != j for i, j in zip(data["im3"].affine.flatten(), data["im2"].affine.flatten())))
    # test the inverse
    data = Invertd("im3", transforms)(data)
    assert_allclose(data["im2"].shape, data["im3"].shape)
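# A minimal standalone sketch of the same round trip, assuming a MetaTensor-era
# MONAI (>= 0.9). The two NIfTI paths below are hypothetical placeholders for
# images on different grids; Invertd replays the operations tracked on "b" in
# reverse order, restoring its original shape and affine.
from monai.transforms import Compose, EnsureChannelFirstd, Invertd, LoadImaged, ResampleToMatchd

fname_a, fname_b = "fixed.nii.gz", "moving.nii.gz"  # hypothetical example paths
xform = Compose([
    LoadImaged(("a", "b")),
    EnsureChannelFirstd(("a", "b")),
    ResampleToMatchd("b", "a"),  # resample "b" onto the grid of "a"
])
out = xform({"a": fname_a, "b": fname_b})
out = Invertd(keys="b", transform=xform, orig_keys="b")(out)  # "b" is back on its own grid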
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
    transform = Compose(
        [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTyped for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for non-linux platforms or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted", "label_inverted", "test_dict"],
        transform=transform,
        orig_keys=["label", "label", "test_dict"],
        meta_keys=[PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None],
        orig_meta_keys=[PostFix.meta("label"), PostFix.meta("label"), None],
        nearest_interp=True,
        to_tensor=[True, False, False],
        device="cpu",
    )
    inverter_1 = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted1", "label_inverted1"],
        transform=transform,
        orig_keys=["image", "image"],
        meta_keys=[PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1")],
        orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")],
        nearest_interp=[True, False],
        to_tensor=[True, True],
        device="cpu",
    )
    expected_keys = [
        "image",
        "image_inverted",
        "image_inverted1",
        PostFix.meta("image_inverted1"),
        PostFix.meta("image_inverted"),
        PostFix.meta("image"),
        "image_transforms",
        "label",
        "label_inverted",
        "label_inverted1",
        PostFix.meta("label_inverted1"),
        PostFix.meta("label_inverted"),
        PostFix.meta("label"),
        "label_transforms",
        "test_dict",
        "test_dict_transforms",
    ]
    # execute 1 epoch
    for d in loader:
        d = decollate_batch(d)
        for item in d:
            item = inverter(item)
            item = inverter_1(item)

            self.assertListEqual(sorted(item), expected_keys)
            self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
            self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
            # check the nearest interpolation mode
            i = item["image_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            i = item["label_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            # test inverted test_dict
            self.assertTrue(isinstance(item["test_dict"]["affine"], np.ndarray))
            self.assertTrue(isinstance(item["test_dict"]["filename_or_obj"], str))
            # check the case that different items use different interpolation mode to invert transforms
            d = item["image_inverted1"]
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

            d = item["label_inverted1"]
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

            # check labels match
            reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
            original = LoadImaged(KEYS)(data[-1])["label"]
            n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
            reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
            original_name = data[-1]["label"]
            self.assertEqual(reverted_name, original_name)
            print("invert diff", reverted.size - n_good)
            # 34007: 2 workers (cpu, non-macos)
            # 1812: 0 workers (gpu or macos)
            # 1821: windows torch 1.10.0
            self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}")

    set_determinism(seed=None)
def __init__(
    self,
    transform: InvertibleTransform,
    loader: TorchDataLoader,
    output_keys: Union[str, Sequence[str]] = CommonKeys.PRED,
    batch_keys: Union[str, Sequence[str]] = CommonKeys.IMAGE,
    meta_key_postfix: str = "meta_dict",
    collate_fn: Optional[Callable] = no_collation,
    postfix: str = "inverted",
    nearest_interp: Union[bool, Sequence[bool]] = True,
    to_tensor: Union[bool, Sequence[bool]] = True,
    device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu",
    post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
    num_workers: Optional[int] = 0,
) -> None:
    """
    Args:
        transform: a callable data transform on input data.
        loader: data loader used to run transforms and generate the batch of data.
        output_keys: the key of the expected data in `ignite.engine.output` to invert transforms on.
            it can also be a list of keys, inverting the transform for each of them. Default to "pred".
        batch_keys: the key of the input data in `ignite.engine.batch`. will get the applied transforms
            for this input data, then invert them for the expected data with `output_keys`.
            It can also be a list of keys, each matching the `output_keys` data. default to "image".
        meta_key_postfix: use `{batch_key}_{postfix}` to fetch the meta data according to the key data,
            default is `meta_dict`, the meta data is a dictionary object.
            For example, to handle key `image`, read/write affine matrices from the metadata
            `image_meta_dict` dictionary's `affine` field.
        collate_fn: how to collate data after inverse transformations.
            default won't do any collation, so the output will be a list whose length is the batch size.
        postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}_{postfix}`.
        nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
            default to `True`. If `False`, use the same interpolation mode as the original transform.
            it can also be a list of bool, each matching the `output_keys` data.
        to_tensor: whether to convert the inverted data into a PyTorch Tensor first, default to `True`.
            it can also be a list of bool, each matching the `output_keys` data.
        device: if converted to Tensor, move the inverted results to the target device before `post_func`,
            default to "cpu". it can also be a list of string or `torch.device`, each matching the
            `output_keys` data.
        post_func: postprocessing for the inverted data, should be a callable function.
            it can also be a list of callables, each matching the `output_keys` data.
        num_workers: number of workers when running the data loader for the inverse transforms,
            default to 0 as it only runs one iteration and multi-processing may be even slower.
            Set to `None` to use the `num_workers` of the input transform data loader.

    """
    self.inverter = Invertd(
        keys=output_keys,
        transform=transform,
        loader=loader,
        orig_keys=batch_keys,
        meta_key_postfix=meta_key_postfix,
        collate_fn=collate_fn,
        postfix=postfix,
        nearest_interp=nearest_interp,
        to_tensor=to_tensor,
        device=device,
        post_func=post_func,
        num_workers=num_workers,
    )
    self.output_keys = ensure_tuple(output_keys)
    self.meta_key_postfix = meta_key_postfix
    self.postfix = postfix
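# Hedged usage sketch for the handler defined above: `val_transforms`, `val_loader`
# and `evaluator` are illustrative placeholders, and the `attach` call assumes the
# usual MONAI ignite-handler convention of registering the handler on an engine.
inverter = TransformInverter(
    transform=val_transforms,
    loader=val_loader,
    output_keys="pred",
    batch_keys="image",
    postfix="inverted",
)
inverter.attach(evaluator)  # inverted results would appear under "pred_inverted" in the engine output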
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)]
    transform = Compose(
        [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord("image"),  # test to support both Tensor and Numpy array when inverting
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        keys=["image", "label"],
        transform=transform,
        loader=loader,
        orig_keys="label",
        nearest_interp=True,
        postfix="inverted",
        to_tensor=[True, False],
        device="cpu",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    )
    # execute 1 epoch
    for d in loader:
        d = inverter(d)
        # this unit test only covers basic functionality; test_handler_transform_inverter covers more
        self.assertTupleEqual(d["image"].shape[1:], (1, 100, 100, 100))
        self.assertTupleEqual(d["label"].shape[1:], (1, 100, 100, 100))
        # check the nearest interpolation mode
        for i in d["image_inverted"]:
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        for i in d["label_inverted"]:
            np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    set_determinism(seed=None)
def test_value_3d(
    self,
    keys,
    data,
    expected_convert_result,
    expected_zoom_result,
    expected_zoom_keepsize_result,
    expected_flip_result,
    expected_clip_result,
    expected_rotate_result,
):
    test_dtype = [torch.float32]
    for dtype in test_dtype:
        data = CastToTyped(keys=["image", "boxes"], dtype=dtype)(data)

        # test ConvertBoxModed
        transform_convert_mode = ConvertBoxModed(**keys)
        convert_result = transform_convert_mode(data)
        assert_allclose(convert_result["boxes"], expected_convert_result, type_test=True, device_test=True, atol=1e-3)

        invert_transform_convert_mode = Invertd(
            keys=["boxes"], transform=transform_convert_mode, orig_keys=["boxes"]
        )
        data_back = invert_transform_convert_mode(convert_result)
        if "boxes_transforms" in data_back:  # if the transform is tracked in the dict, it should be updated
            self.assertEqual(data_back["boxes_transforms"], [])
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)

        # test ZoomBoxd
        transform_zoom = ZoomBoxd(
            image_keys="image", box_keys="boxes", box_ref_image_keys="image", zoom=[0.5, 3, 1.5], keep_size=False
        )
        zoom_result = transform_zoom(data)
        self.assertEqual(len(zoom_result["image"].applied_operations), 1)
        assert_allclose(zoom_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=1e-3)

        invert_transform_zoom = Invertd(keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"])
        data_back = invert_transform_zoom(zoom_result)
        self.assertEqual(data_back["image"].applied_operations, [])
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
        assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)

        transform_zoom = ZoomBoxd(
            image_keys="image", box_keys="boxes", box_ref_image_keys="image", zoom=[0.5, 3, 1.5], keep_size=True
        )
        zoom_result = transform_zoom(data)
        self.assertEqual(len(zoom_result["image"].applied_operations), 1)
        assert_allclose(
            zoom_result["boxes"], expected_zoom_keepsize_result, type_test=True, device_test=True, atol=1e-3
        )

        # test RandZoomBoxd
        transform_zoom = RandZoomBoxd(
            image_keys="image",
            box_keys="boxes",
            box_ref_image_keys="image",
            prob=1.0,
            min_zoom=(0.3,) * 3,
            max_zoom=(3.0,) * 3,
            keep_size=False,
        )
        zoom_result = transform_zoom(data)
        self.assertEqual(len(zoom_result["image"].applied_operations), 1)

        invert_transform_zoom = Invertd(keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"])
        data_back = invert_transform_zoom(zoom_result)
        self.assertEqual(data_back["image"].applied_operations, [])
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)
        assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)

        # test AffineBoxToImageCoordinated, AffineBoxToWorldCoordinated
        transform_affine = AffineBoxToImageCoordinated(box_keys="boxes", box_ref_image_keys="image")
        if not isinstance(data["image"], MetaTensor):
            # a missing meta dict is an exception
            with self.assertRaises(Exception) as context:
                transform_affine(deepcopy(data))
            self.assertTrue("Please check whether it is the correct the image meta key." in str(context.exception))

        data["image"] = MetaTensor(data["image"], meta={"affine": torch.diag(1.0 / torch.Tensor([0.5, 3, 1.5, 1]))})
        affine_result = transform_affine(data)
        if "boxes_transforms" in affine_result:
            self.assertEqual(len(affine_result["boxes_transforms"]), 1)
        assert_allclose(affine_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=0.01)

        invert_transform_affine = Invertd(keys=["boxes"], transform=transform_affine, orig_keys=["boxes"])
        data_back = invert_transform_affine(affine_result)
        if "boxes_transforms" in data_back:
            self.assertEqual(data_back["boxes_transforms"], [])
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)

        invert_transform_affine = AffineBoxToWorldCoordinated(box_keys="boxes", box_ref_image_keys="image")
        data_back = invert_transform_affine(affine_result)
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)

        # test FlipBoxd
        transform_flip = FlipBoxd(
            image_keys="image", box_keys="boxes", box_ref_image_keys="image", spatial_axis=[0, 1, 2]
        )
        flip_result = transform_flip(data)
        if "boxes_transforms" in flip_result:
            self.assertEqual(len(flip_result["boxes_transforms"]), 1)
        assert_allclose(flip_result["boxes"], expected_flip_result, type_test=True, device_test=True, atol=1e-3)

        invert_transform_flip = Invertd(keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"])
        data_back = invert_transform_flip(flip_result)
        if "boxes_transforms" in data_back:
            self.assertEqual(data_back["boxes_transforms"], [])
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
        assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)

        # test RandFlipBoxd
        for spatial_axis in [(0,), (1,), (2,), (0, 1), (1, 2)]:
            transform_flip = RandFlipBoxd(
                image_keys="image",
                box_keys="boxes",
                box_ref_image_keys="image",
                prob=1.0,
                spatial_axis=spatial_axis,
            )
            flip_result = transform_flip(data)
            if "boxes_transforms" in flip_result:
                self.assertEqual(len(flip_result["boxes_transforms"]), 1)

            invert_transform_flip = Invertd(
                keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"]
            )
            data_back = invert_transform_flip(flip_result)
            if "boxes_transforms" in data_back:
                self.assertEqual(data_back["boxes_transforms"], [])
            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
            assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)

        # test ClipBoxToImaged
        transform_clip = ClipBoxToImaged(
            box_keys="boxes", box_ref_image_keys="image", label_keys=["labels", "scores"], remove_empty=True
        )
        clip_result = transform_clip(data)
        assert_allclose(clip_result["boxes"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)
        assert_allclose(clip_result["labels"], data["labels"][1:], type_test=True, device_test=True, atol=1e-3)
        assert_allclose(clip_result["scores"], data["scores"][1:], type_test=True, device_test=True, atol=1e-3)

        # corner case when label_keys is empty
        transform_clip = ClipBoxToImaged(box_keys="boxes", box_ref_image_keys="image", label_keys=[], remove_empty=True)
        clip_result = transform_clip(data)
        assert_allclose(clip_result["boxes"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)

        # test RandCropBoxByPosNegLabeld
        transform_crop = RandCropBoxByPosNegLabeld(
            image_keys="image", box_keys="boxes", label_keys=["labels", "scores"], spatial_size=2, num_samples=3
        )
        crop_result = transform_crop(data)
        assert len(crop_result) == 3
        for ll in range(3):
            assert_allclose(
                crop_result[ll]["boxes"].shape[0],
                crop_result[ll]["labels"].shape[0],
                type_test=True,
                device_test=True,
                atol=1e-3,
            )
            assert_allclose(
                crop_result[ll]["boxes"].shape[0],
                crop_result[ll]["scores"].shape[0],
                type_test=True,
                device_test=True,
                atol=1e-3,
            )

        # test RotateBox90d
        transform_rotate = RotateBox90d(
            image_keys="image", box_keys="boxes", box_ref_image_keys="image", k=1, spatial_axes=[0, 1]
        )
        rotate_result = transform_rotate(data)
        self.assertEqual(len(rotate_result["image"].applied_operations), 1)
        assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=True, device_test=True, atol=1e-3)

        invert_transform_rotate = Invertd(
            keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"]
        )
        data_back = invert_transform_rotate(rotate_result)
        self.assertEqual(data_back["image"].applied_operations, [])
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
        assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)

        transform_rotate = RandRotateBox90d(
            image_keys="image", box_keys="boxes", box_ref_image_keys="image", prob=1.0, max_k=3, spatial_axes=[0, 1]
        )
        rotate_result = transform_rotate(data)
        self.assertEqual(len(rotate_result["image"].applied_operations), 1)

        invert_transform_rotate = Invertd(
            keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"]
        )
        data_back = invert_transform_rotate(rotate_result)
        self.assertEqual(data_back["image"].applied_operations, [])
        assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
        assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose(
        [
            LoadImaged(keys="img"),
            EnsureChannelFirstd(keys="img"),
            Orientationd(keys="img", axcodes="RAS"),
            Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys="img"),
        ]
    )
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose(
        [
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", sigmoid=True),
            Invertd(
                keys="pred",  # invert the `pred` data field, also supports multiple fields
                transform=pre_transforms,
                orig_keys="img",  # get the previously applied pre_transforms information on the `img` data field,
                # then invert `pred` based on this information. we can use the same info
                # for multiple fields, and different orig_keys for different fields are also supported
                meta_keys="pred_meta_dict",  # key field to save the inverted meta data, every item maps to `keys`
                orig_meta_keys="img_meta_dict",  # get the meta data from the `img_meta_dict` field when inverting,
                # for example, may need the `affine` to invert the `Spacingd` transform,
                # multiple fields can use the same meta data to invert
                meta_key_postfix="meta_dict",  # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
                # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
                # otherwise, this arg is not needed during inverting
                nearest_interp=False,  # don't change the interpolation mode to "nearest" when inverting transforms
                # to ensure a smooth output, then execute the `AsDiscreted` transform
                to_tensor=True,  # convert to PyTorch Tensor after inverting
            ),
            AsDiscreted(keys="pred", threshold=0.5),
            SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images, roi_size=(96, 96, 96), sw_batch_size=4, predictor=net)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]
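# The usual entry point for these MONAI demo scripts: a sketch assuming `tempfile`
# is imported alongside the other modules used above.
if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)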
def __init__(
    self,
    transform: InvertibleTransform,
    loader: TorchDataLoader,
    output_keys: KeysCollection = CommonKeys.PRED,
    batch_keys: KeysCollection = CommonKeys.IMAGE,
    meta_keys: Optional[KeysCollection] = None,
    batch_meta_keys: Optional[KeysCollection] = None,
    meta_key_postfix: str = "meta_dict",
    collate_fn: Optional[Callable] = no_collation,
    nearest_interp: Union[bool, Sequence[bool]] = True,
    to_tensor: Union[bool, Sequence[bool]] = True,
    device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu",
    post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
    num_workers: Optional[int] = 0,
) -> None:
    """
    Args:
        transform: a callable data transform on input data.
        loader: data loader used to run transforms and generate the batch of data.
        output_keys: the key of the expected data in `ignite.engine.output` to invert transforms on.
            it can also be a list of keys, inverting the transform for each of them.
            Default to "pred". it's an in-place operation.
        batch_keys: the key of the input data in `ignite.engine.batch`. will get the applied transforms
            for this input data, then invert them for the expected data with `output_keys`.
            It can also be a list of keys, each matching the `output_keys` data. default to "image".
        meta_keys: explicitly indicate the key for the inverted meta data dictionary.
            the meta data is a dictionary object which contains: filename, original_shape, etc.
            it can be a sequence of strings, mapping to the `keys`.
            if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`.
        batch_meta_keys: the key of the meta data of the input data in `ignite.engine.batch`,
            will get the `affine`, `data_shape`, etc.
            the meta data is a dictionary object which contains: filename, original_shape, etc.
            it can be a sequence of strings, mapping to the `keys`.
            if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`.
            meta data will also be inverted and stored in `meta_keys`.
        meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to
            fetch the meta data from the dict; if `meta_keys` is None, use `{key}_{meta_key_postfix}`.
            default is `meta_dict`, the meta data is a dictionary object.
            For example, to handle orig_key `image`, read/write `affine` matrices from the metadata
            `image_meta_dict` dictionary's `affine` field.
            the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}".
        collate_fn: how to collate data after inverse transformations. default won't do any collation,
            so the output will be a list of PyTorch Tensors or numpy arrays without the batch dim.
        nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
            default to `True`. If `False`, use the same interpolation mode as the original transform.
            it can also be a list of bool, each matching the `output_keys` data.
        to_tensor: whether to convert the inverted data into a PyTorch Tensor first, default to `True`.
            it can also be a list of bool, each matching the `output_keys` data.
        device: if converted to Tensor, move the inverted results to the target device before `post_func`,
            default to "cpu". it can also be a list of string or `torch.device`, each matching the
            `output_keys` data.
        post_func: postprocessing for the inverted data, should be a callable function.
            it can also be a list of callables, each matching the `output_keys` data.
        num_workers: number of workers when running the data loader for the inverse transforms,
            default to 0 as it only runs one iteration and multi-processing may be even slower.
            Set to `None` to use the `num_workers` of the input transform data loader.

    """
    self.inverter = Invertd(
        keys=output_keys,
        transform=transform,
        loader=loader,
        orig_keys=batch_keys,
        meta_keys=meta_keys,
        orig_meta_keys=batch_meta_keys,
        meta_key_postfix=meta_key_postfix,
        collate_fn=collate_fn,
        nearest_interp=nearest_interp,
        to_tensor=to_tensor,
        device=device,
        post_func=post_func,
        num_workers=num_workers,
    )
    self.output_keys = ensure_tuple(output_keys)
    self.meta_keys = ensure_tuple_rep(None, len(self.output_keys)) if meta_keys is None else ensure_tuple(meta_keys)
    if len(self.output_keys) != len(self.meta_keys):
        raise ValueError("meta_keys should have the same length as output_keys.")
    self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.output_keys))
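# Hedged illustration of the length check above: with two output keys, an explicit
# `meta_keys` must also contain two entries, otherwise the ValueError is raised.
# `val_transforms` and `val_loader` are hypothetical placeholders.
inverter = TransformInverter(
    transform=val_transforms,
    loader=val_loader,
    output_keys=["pred", "label"],
    batch_keys=["image", "image"],
    meta_keys=["pred_meta_dict", "label_meta_dict"],
)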
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose(
        [
            LoadImaged(keys="img"),
            EnsureChannelFirstd(keys="img"),
            Orientationd(keys="img", axcodes="RAS"),
            Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
            ScaleIntensityd(keys="img"),
            ToTensord(keys="img"),
        ]
    )
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            Invertd(keys="pred", transform=pre_transforms, loader=dataloader, orig_keys="img", nearest_interp=True),
            SaveImaged(keys="pred_inverted", output_dir="./output", output_postfix="seg", resample=False),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images, roi_size=(96, 96, 96), sw_batch_size=4, predictor=net)
            # execute post transforms to invert spatial transforms and save to NIfTI files
            post_transforms(d)
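# As with the newer script above, the expected entry point is a sketch assuming
# `tempfile` is imported: generate the data in a temporary directory and run `main`.
if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)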