def __init__(self, tranforms):
    """Instantiate the named transforms and hand them to the parent class.

    Every entry of ``tranforms`` is matched against the supported names
    below; a fresh transform instance is created per entry, in input order,
    collected in ``self.tranform_list`` and passed to ``super().__init__``
    (presumably a ``Compose`` subclass — confirm against the class header).

    Args:
        tranforms: iterable of transform names (strings).

    Raises:
        ValueError: if an entry is not one of the supported names.
    """
    # (name, factory) pairs. Factories are lambdas so a new instance is made
    # for each occurrence, exactly like the original if/elif chain.
    factories = (
        ("LoadImaged", lambda: LoadImaged(keys=["image", "label"])),
        ("AsChannelFirstd", lambda: AsChannelFirstd(keys="image")),
        (
            "ConvertToMultiChannelBasedOnBratsClassesd",
            lambda: ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        ),
        (
            "Spacingd",
            lambda: Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ),
        ("Orientationd", lambda: Orientationd(keys=["image", "label"], axcodes="RAS")),
        ("CenterSpatialCropd", lambda: CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64])),
        ("NormalizeIntensityd", lambda: NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True)),
        ("ToTensord", lambda: ToTensord(keys=["image", "label"])),
        ("Activations", lambda: Activations(sigmoid=True)),
        ("AsDiscrete", lambda: AsDiscrete(threshold_values=True)),
    )
    self.tranform_list = []
    for requested in tranforms:
        # Linear scan with ``name == requested`` keeps the original's
        # equality semantics (no hashing of ``requested``).
        factory = next((make for name, make in factories if name == requested), None)
        if factory is None:
            raise ValueError(
                f"Unsupported tranform: {requested}. Please add it to support it."
            )
        self.tranform_list.append(factory())
    super().__init__(self.tranform_list)
def pre_transforms(self, data=None):
    """Return the inference pre-processing transforms for the image.

    Common steps: load with the ITK reader, make channel-first, reorient to
    RAS, and map intensities from [-175, 250] to [0, 1] with clipping.
    In DeepEdit mode the click points are turned into guidance, the image
    and guidance are resized, and guidance-signal channels are added. In
    any other mode the image is resized and passed through
    ``DiscardAddGuidanced`` (presumably adding empty guidance channels —
    confirm). The list always ends with ``EnsureTyped``.

    Args:
        data: optional request dict; only ``data["device"]`` is read, and
            only when ``data`` is truthy.

    Returns:
        list of transforms.
    """
    transforms = [
        LoadImaged(keys="image", reader="ITKReader"),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys="image", axcodes="RAS"),
        ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    ]
    is_deepedit = self.type == InferType.DEEPEDIT
    if is_deepedit:
        transforms += [
            AddGuidanceFromPointsCustomd(ref_image="image", guidance="guidance", label_names=self.labels),
            Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
            ResizeGuidanceMultipleLabelCustomd(guidance="guidance", ref_image="image"),
            AddGuidanceSignalCustomd(
                keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch
            ),
        ]
    else:
        transforms += [
            Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
            DiscardAddGuidanced(
                keys="image", label_names=self.labels, number_intensity_ch=self.number_intensity_ch
            ),
        ]
    device = data.get("device") if data else None
    transforms.append(EnsureTyped(keys="image", device=device))
    return transforms
def pre_transforms(self, data):
    """Return the pre-processing chain for scribbles-based likelihood.

    In order:
    - load image and label, and make both channel-first;
    - run ``AddBackgroundScribblesFromROId`` on the label (by its name,
      adds background scribbles derived from an ROI);
    - resample to ``self.pix_dim`` and reorient to RAS;
    - rescale image intensities using ``self.intensity_range``;
    - build a likelihood from the scribbles histogram, written to
      ``"prob"``.

    Args:
        data: not used by this method.

    Returns:
        list of transforms.
    """
    image_and_label = ["image", "label"]
    # intensity_range is indexed as (a_min, a_max, b_min, b_max, clip)
    window = self.intensity_range
    return [
        LoadImaged(keys=image_and_label),
        EnsureChannelFirstd(keys=image_and_label),
        AddBackgroundScribblesFromROId(
            scribbles="label",
            scribbles_bg_label=self.scribbles_bg_label,
            scribbles_fg_label=self.scribbles_fg_label,
        ),
        # The optimisers are currently the bottleneck, so resample to a
        # large, non-isotropic spacing to keep them fast.
        Spacingd(keys=image_and_label, pixdim=self.pix_dim, mode=["bilinear", "nearest"]),
        Orientationd(keys=image_and_label, axcodes="RAS"),
        ScaleIntensityRanged(
            keys="image",
            a_min=window[0],
            a_max=window[1],
            b_min=window[2],
            b_max=window[3],
            clip=window[4],
        ),
        MakeLikelihoodFromScribblesHistogramd(
            image="image",
            scribbles="label",
            post_proc_label="prob",
            scribbles_bg_label=self.scribbles_bg_label,
            scribbles_fg_label=self.scribbles_fg_label,
            normalise=True,
        ),
    ]
def get_seg_transforms(end_seg_axcodes):
    """Return a ``Compose`` that loads ``"image"`` and reorients it.

    The NIfTI file is loaded without conversion to the closest canonical
    orientation, given a channel dimension, and reoriented to
    ``end_seg_axcodes``.

    Args:
        end_seg_axcodes: target orientation codes passed to ``Orientationd``.
    """
    key = ["image"]
    steps = [
        LoadNiftid(keys=key, as_closest_canonical=False),
        AddChanneld(keys=key),
        Orientationd(keys=key, axcodes=end_seg_axcodes),
    ]
    return Compose(steps)
def val_pre_transforms(self, context: Context):
    """Return validation pre-processing followed by click simulation.

    Pre-processing:
    - load image and label with the ITK reader;
    - run ``NormalizeLabelsInDatasetd`` with ``self._labels``;
    - make both channel-first and reorient to RAS;
    - map image intensities from [-175, 250] to [0, 1] (clipped);
    - resize both to ``self.spatial_size`` (area for the image, nearest for
      the label).

    Click simulation then runs slice discovery, initial seed placement and
    guidance-signal addition. Only image, label, guidance and label_names
    are kept.

    Args:
        context: not used by this method.

    Returns:
        list of transforms.
    """
    image_label = ("image", "label")
    preprocessing = [
        LoadImaged(keys=image_label, reader="ITKReader"),
        NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
        EnsureChannelFirstd(keys=image_label),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # This transform may not work well for MR images
        ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        Resized(keys=image_label, spatial_size=self.spatial_size, mode=("area", "nearest")),
    ]
    # Transforms for click simulation
    click_simulation = [
        FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
        AddInitialSeedPointMissingLabelsd(keys="label", guidance="guidance", sids="sids"),
        AddGuidanceSignalCustomd(keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch),
    ]
    selection = [SelectItemsd(keys=("image", "label", "guidance", "label_names"))]
    return preprocessing + click_simulation + selection
def test_orntd_no_metadata(self):
    """Orientationd on data without a meta dict keeps the shape and yields RAS codes."""
    transform = Orientationd(keys="seg", axcodes="RAS")
    sample = {"seg": np.ones((2, 1, 2, 3))}
    result = transform(sample)
    np.testing.assert_allclose(result["seg"].shape, (2, 1, 2, 3))
    out_affine = result["seg_meta_dict"]["affine"]
    self.assertEqual(nib.aff2axcodes(out_affine, transform.ornt_transform.labels), ("R", "A", "S"))
def test_orntd(self):
    """Reorienting an identity-affine volume to RAS keeps its shape and RAS codes."""
    transform = Orientationd(keys='seg', axcodes='RAS')
    sample = {'seg': np.ones((2, 1, 2, 3)), 'seg.affine': np.eye(4)}
    out = transform(sample)
    np.testing.assert_allclose(out['seg'].shape, (2, 1, 2, 3))
    axcodes = nib.aff2axcodes(out['seg.affine'], transform.ornt_transform.labels)
    self.assertEqual(axcodes, ('R', 'A', 'S'))
def get_transforms(self):
    """Return the ``(train, val, test)`` transform pipelines.

    Shared by all three pipelines:
    - load ``image``/``label`` NIfTI files and add a channel dimension;
    - reorient to RAS and normalise image intensities;
    - end with ``ToTensord``.

    Per pipeline:
    - train: pad to ``self.pad_crop_shape``, flip randomly on spatial
      axis 0, then randomly crop to ``self.pad_crop_shape``;
    - val: pad and randomly crop, without flipping;
    - test: no padding or cropping.

    NOTE(review): validation uses ``RandSpatialCropd`` with
    ``random_center=True``, so validation inputs differ between runs —
    confirm that is intended.
    """
    self.logger.info("Getting transforms...")
    keys = ["image", "label"]
    crop_shape = self.pad_crop_shape

    def _common_prefix():
        # Rebuilt for each pipeline so the three Compose objects never share
        # transform instances.
        return [
            LoadNiftid(keys=keys),
            AddChanneld(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            NormalizeIntensityd(keys=["image"]),
        ]

    train_transforms = Compose(
        _common_prefix()
        + [
            SpatialPadd(keys=keys, spatial_size=crop_shape),
            RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
            RandSpatialCropd(keys=keys, roi_size=crop_shape, random_center=True, random_size=False),
            ToTensord(keys=keys),
        ]
    )
    val_transforms = Compose(
        _common_prefix()
        + [
            SpatialPadd(keys=keys, spatial_size=crop_shape),
            RandSpatialCropd(keys=keys, roi_size=crop_shape, random_center=True, random_size=False),
            ToTensord(keys=keys),
        ]
    )
    test_transforms = Compose(_common_prefix() + [ToTensord(keys=keys)])
    return train_transforms, val_transforms, test_transforms
def _default_transforms(image_key, label_key, pixdim):
    """Build the default load / resample / reorient pipeline.

    Args:
        image_key: key of the image entry.
        label_key: key of the label entry, or ``None`` for image-only data.
        pixdim: target spacing forwarded to ``Spacingd``.

    Returns:
        ``Compose`` of LoadImaged, AsChannelFirstd, Spacingd (bilinear for
        the image, nearest for the label) and Orientationd to RAS.
    """
    if label_key is None:
        keys = [image_key]
        mode = [GridSampleMode.BILINEAR]
    else:
        keys = [image_key, label_key]
        mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST]
    pipeline = [
        LoadImaged(keys=keys),
        AsChannelFirstd(keys=keys),
        Spacingd(keys=keys, pixdim=pixdim, mode=mode),
        Orientationd(keys=keys, axcodes="RAS"),
    ]
    return Compose(pipeline)
def test_load_spacingd_rotate_non_diag_ornt(self):
    """Non-diagonal Spacingd then Orientationd(LPI) on FILES[0] gives the expected affine."""
    loaded = LoadNiftid(keys='image')({'image': FILES[0]})
    loaded = AddChanneld(keys='image')(loaded)
    resampled = Spacingd(keys='image', pixdim=(1, 2, 3), diagonal=False, mode='nearest')(loaded)
    result = Orientationd(keys='image', axcodes='LPI')(resampled)
    # The pre-resampling affine must be preserved as ``original_affine``.
    np.testing.assert_allclose(loaded['image.affine'], result['image.original_affine'])
    expected = np.array([
        [-1., 0., 0., 32.],
        [0., -2., 0., 40.],
        [0., 0., -3., 32.],
        [0., 0., 0., 1.],
    ])
    np.testing.assert_allclose(result['image.affine'], expected)
def _get_transforms(keys, pixdim):
    """Return a ``Compose`` that loads, resamples to ``pixdim`` and reorients to RAS.

    Args:
        keys: data keys. With exactly two keys the second is resampled with
            nearest interpolation; otherwise a single bilinear mode is used.
        pixdim: target spacing forwarded to ``Spacingd``.
    """
    modes = [GridSampleMode.BILINEAR]
    if len(keys) == 2:
        modes.append(GridSampleMode.NEAREST)
    return Compose([
        LoadImaged(keys=keys),
        AsChannelFirstd(keys=keys),
        Spacingd(keys=keys, pixdim=pixdim, mode=modes),
        Orientationd(keys=keys, axcodes="RAS"),
    ])
def get_test_transforms(end_image_shape):
    """Return the test-time transform chain for the ``"image"`` key.

    Steps:
    - load the NIfTI image, add a channel dimension, reorient to RAS;
    - apply the project transform ``Winsorized`` (presumably clips
      intensity outliers — confirm);
    - normalise and rescale intensities;
    - pad to ``end_image_shape`` and convert to a tensor.

    Args:
        end_image_shape: spatial size passed to ``SpatialPadd``.
    """
    steps = [
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Winsorized(keys=["image"]),
        NormalizeIntensityd(keys=["image"]),
        ScaleIntensityd(keys=["image"]),
        SpatialPadd(keys=["image"], spatial_size=end_image_shape),
        ToTensord(keys=["image"]),
    ]
    return Compose(steps)
def test_orntd(self):
    """RAS reorientation of an identity-affine volume stored under PostFix.meta keys."""
    meta_key = PostFix.meta("seg")
    transform = Orientationd(keys="seg", axcodes="RAS")
    sample = {"seg": np.ones((2, 1, 2, 3)), meta_key: {"affine": np.eye(4)}}
    out = transform(sample)
    np.testing.assert_allclose(out["seg"].shape, (2, 1, 2, 3))
    axcodes = nib.aff2axcodes(out[meta_key]["affine"], transform.ornt_transform.labels)
    self.assertEqual(axcodes, ("R", "A", "S"))
def get_xforms_load(mode="load", keys=("image", "label")):
    """Return a composed loading transform for an image and optional label.

    Steps:
    - load ``keys`` and add a channel dimension;
    - reorient to LPS;
    - resample to 1.25 x 1.25 x 5.0 spacing (bilinear for the first key,
      nearest for the second).

    Intensities are left unscaled. In ``"load"`` mode the outputs are also
    cast (image -> int16, label -> uint8) and converted to tensors.

    Args:
        mode: ``"load"`` appends the cast and to-tensor steps; any other
            value returns only the spatial pre-processing.
        keys: one key (image only) or two keys (image, label).

    Returns:
        ``monai.transforms.Compose`` of the steps above.
    """
    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
    ]
    if mode == "load":
        # Slice like the Spacingd modes above. MONAI expands ``dtype`` to one
        # entry per key and rejects a 2-tuple when only the image key is
        # given.
        dtype = (np.int16, np.uint8)[: len(keys)]
        xforms.extend([CastToTyped(keys, dtype=dtype), ToTensord(keys)])
    return monai.transforms.Compose(xforms)
def test_load_spacingd_rotate_non_diag_ornt(self):
    """Load FILES[0], resample non-diagonally, reorient to LPI and check the meta affine."""
    chain = (
        LoadImaged(keys="image"),
        AddChanneld(keys="image"),
        Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border"),
        Orientationd(keys="image", axcodes="LPI"),
    )
    sample = {"image": FILES[0]}
    for step in chain:
        sample = step(sample)
    expected = np.array([
        [-1.0, 0.0, 0.0, 32.0],
        [0.0, -2.0, 0.0, 40.0],
        [0.0, 0.0, -3.0, 32.0],
        [0.0, 0.0, 0.0, 1.0],
    ])
    np.testing.assert_allclose(sample[PostFix.meta("image")]["affine"], expected)
def test_load_spacingd_non_diag_ornt(self):
    """Rotate FILES[1]'s affine, resample non-diagonally, reorient to LPI and check both affines."""
    loaded = LoadNiftid(keys='image')({'image': FILES[1]})
    loaded = AddChanneld(keys='image')(loaded)
    rotation = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]])
    # Store the rotated affine as both the current and the original affine.
    loaded['image.original_affine'] = loaded['image.affine'] = rotation @ loaded['image.affine']
    resampled = Spacingd(keys='image', pixdim=(1, 2, 3), diagonal=False, mode='constant')(loaded)
    result = Orientationd(keys='image', axcodes='LPI')(resampled)
    np.testing.assert_allclose(loaded['image.affine'], result['image.original_affine'])
    expected = np.array([
        [-3., 0., 0., 56.4005909],
        [0., -2., 0., 52.02241516],
        [0., 0., -1., 35.29789734],
        [0., 0., 0., 1.],
    ])
    np.testing.assert_allclose(result['image.affine'], expected)
def test_orntd_2d(self):
    """PLI reorientation of 2-D data swaps the two spatial axes; the third code stays S."""
    sample = {
        "seg": np.ones((2, 1, 3)),
        "img": np.ones((2, 1, 3)),
        "seg.affine": np.eye(4),
        "img.affine": np.eye(4),
    }
    transform = Orientationd(keys=("img", "seg"), axcodes="PLI")
    out = transform(sample)
    np.testing.assert_allclose(out["img"].shape, (2, 3, 1))
    for affine_key in ("seg.affine", "img.affine"):
        axcodes = nib.aff2axcodes(out[affine_key], transform.ornt_transform.labels)
        self.assertEqual(axcodes, ("P", "L", "S"))
def test_orntd_1d(self):
    """Reorienting 1-D data to "L" keeps the shape; the remaining codes stay A and S."""
    sample = {
        "seg": np.ones((2, 3)),
        "img": np.ones((2, 3)),
        "seg_meta_dict": {"affine": np.eye(4)},
        "img_meta_dict": {"affine": np.eye(4)},
    }
    transform = Orientationd(keys=("img", "seg"), axcodes="L")
    out = transform(sample)
    np.testing.assert_allclose(out["img"].shape, (2, 3))
    for meta_key in ("seg_meta_dict", "img_meta_dict"):
        axcodes = nib.aff2axcodes(out[meta_key]["affine"], transform.ornt_transform.labels)
        self.assertEqual(axcodes, ("L", "A", "S"))
def test_load_spacingd_rotate_non_diag_ornt(self):
    """Compose(Spacingd, Orientationd LPI) on FILES[0] gives the expected MetaTensor affine."""
    sample = self.load_image(FILES[0])
    pipeline = Compose([
        Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border"),
        Orientationd(keys="image", axcodes="LPI"),
    ])
    out = pipeline(sample)
    expected = np.array([
        [-1.0, 0.0, 0.0, 32.0],
        [0.0, -2.0, 0.0, 40.0],
        [0.0, 0.0, -3.0, 32.0],
        [0.0, 0.0, 0.0, 1.0],
    ])
    np.testing.assert_allclose(out["image"].affine, expected)
def test_orntd_canonical(self):
    """as_closest_canonical on identity affines leaves the shapes unchanged and yields RAS."""
    sample = {
        "seg": np.ones((2, 1, 2, 3)),
        "img": np.ones((2, 1, 2, 3)),
        "seg.affine": np.eye(4),
        "img.affine": np.eye(4),
    }
    transform = Orientationd(keys=("img", "seg"), as_closest_canonical=True)
    out = transform(sample)
    for key in ("img", "seg"):
        np.testing.assert_allclose(out[key].shape, (2, 1, 2, 3))
    for affine_key in ("seg.affine", "img.affine"):
        axcodes = nib.aff2axcodes(out[affine_key], transform.ornt_transform.labels)
        self.assertEqual(axcodes, ("R", "A", "S"))
def get_xforms(args, mode="train", keys=("image", "label")):
    """Return a composed transform for train/val/infer.

    Common steps:
    - load ``keys`` and add a channel dimension;
    - reorient to LPS;
    - resample to 1.25 x 1.25 x 5.0 (bilinear for the image, nearest for
      the label);
    - map image intensities from [-1000, 500] to [0, 1] (clipped).

    ``"train"`` additionally adds reflect padding, a random affine, random
    pos/neg-label crops (3 samples of ``patch_size x patch_size x n_slice``),
    Gaussian noise and random flips on all three spatial axes. The result is
    always cast (image -> float32, label -> uint8) and converted to tensors.

    Args:
        args: namespace providing ``patch_size`` and ``n_slice`` (read only
            in ``"train"`` mode).
        mode: ``"train"``, ``"val"`` or ``"infer"``; only ``"train"`` adds
            augmentation.
        keys: image key, optionally followed by the label key. ``"train"``
            requires both.

    Returns:
        ``monai.transforms.Compose`` of the steps above.
    """
    xforms = [
        LoadNiftid(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]
    if mode == "train":
        xforms.extend([
            # ensure the in-plane size is at least patch_size x patch_size
            SpatialPadd(keys, spatial_size=(args.patch_size, args.patch_size, -1), mode="reflect"),
            RandAffined(
                keys,
                prob=0.15,
                rotate_range=(-0.05, 0.05),
                scale_range=(-0.1, 0.1),
                mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            RandCropByPosNegLabeld(
                keys,
                label_key=keys[1],
                spatial_size=(args.patch_size, args.patch_size, args.n_slice),
                num_samples=3,
            ),
            RandGaussianNoised(keys[0], prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
        ])
    # One dtype per key: image -> float32, label (if present) -> uint8.
    # Deriving it from ``keys`` (rather than from ``mode``) avoids a
    # dtype/key count mismatch, e.g. ``mode="infer"`` with both keys.
    dtype = (np.float32, np.uint8)[: len(keys)]
    xforms.extend([CastToTyped(keys, dtype=dtype), ToTensord(keys)])
    return monai.transforms.Compose(xforms)
def test_load_spacingd_rotate_non_diag_ornt(self):
    """Non-diagonal Spacingd then Orientationd(LPI) on FILES[0] gives the expected affine."""
    loaded = LoadNiftid(keys="image")({"image": FILES[0]})
    loaded = AddChanneld(keys="image")(loaded)
    resampled = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, mode="nearest")(loaded)
    result = Orientationd(keys="image", axcodes="LPI")(resampled)
    # The pre-resampling affine must survive as ``original_affine``.
    np.testing.assert_allclose(loaded["image.affine"], result["image.original_affine"])
    expected = np.array([
        [-1.0, 0.0, 0.0, 32.0],
        [0.0, -2.0, 0.0, 40.0],
        [0.0, 0.0, -3.0, 32.0],
        [0.0, 0.0, 0.0, 1.0],
    ])
    np.testing.assert_allclose(result["image.affine"], expected)
def test_orntd(self, init_param, img: torch.Tensor, affine: Optional[torch.Tensor], expected_shape, expected_code, device):
    """Parameterized Orientationd check: output type, shape and axis codes per key.

    When ``affine`` is given, the input is wrapped in a ``MetaTensor``
    carrying it before being moved to ``device``.
    """
    ornt = Orientationd(**init_param)
    source = img if affine is None else MetaTensor(img, affine=affine)
    source = source.to(device)
    result = ornt({key: source.clone() for key in ornt.keys})
    for key in ornt.keys:
        out = result[key]
        self.assertIsInstance(out, MetaTensor)
        np.testing.assert_allclose(out.shape, expected_shape)
        axcodes = nib.aff2axcodes(out.affine.cpu(), ornt.ornt_transform.labels)
        self.assertEqual("".join(axcodes), expected_code)
def main():
    """Run CopleNet sliding-window inference on ``case*.nii.gz`` files in IMAGE_FOLDER.

    For each image:
    - load it, add a channel dimension and reorient to SPL;
    - segment it with sliding-window inference;
    - save the per-voxel argmax over the channel dim to OUTPUT_FOLDER with
      ``NiftiSaver``, using the loaded image metadata.

    Runs on CUDA when available and falls back to CPU otherwise.
    """
    images = sorted(glob(os.path.join(IMAGE_FOLDER, "case*.nii.gz")))
    val_files = [{"img": img} for img in images]

    # define transforms for image and segmentation
    infer_transforms = Compose([
        LoadNiftid("img"),
        AddChanneld("img"),
        Orientationd("img", "SPL"),  # coplenet works on the plane defined by the last two axes
        ToTensord("img"),
    ])
    test_ds = monai.data.Dataset(data=val_files, transform=infer_transforms)
    # sliding window inference needs one image per iteration
    data_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available()
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CopleNet().to(device)
    # map_location is required for the CPU fallback above: without it a
    # checkpoint saved from GPU tensors fails to load on a CPU-only machine.
    checkpoint = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    with torch.no_grad():
        saver = NiftiSaver(output_dir=OUTPUT_FOLDER)
        for idx, val_data in enumerate(data_loader):
            print(f"Inference on {idx+1} of {len(data_loader)}")
            val_images = val_data["img"].to(device)
            # In-plane window: the dims from index 3 onward, rounded up to a
            # multiple of 32 (assumes a batch, channel, 3-spatial-dim input —
            # confirm); 20 slices along the first spatial axis.
            slice_shape = np.ceil(np.asarray(val_images.shape[3:]) / 32) * 32
            roi_size = (20, int(slice_shape[0]), int(slice_shape[1]))
            sw_batch_size = 2
            val_outputs = sliding_window_inference(
                val_images, roi_size, sw_batch_size, model, overlap=0.0, padding_mode="circular"
            )
            val_outputs = val_outputs.argmax(dim=1, keepdim=True)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):
    """Orientationd on tensors with MONAI metadata tracking switched on or off.

    The shape must be unchanged in both cases.
    - With tracking on: the output is a ``MetaTensor`` whose affine matches
      the requested axis codes.
    - With tracking off: the output is a plain ``torch.Tensor``.

    The global tracking flag is restored afterwards.
    """
    set_track_meta(track_meta)
    try:
        ornt = Orientationd(**init_param)
        img = img.to(device)
        expected_shape = img.shape
        expected_code = ornt.ornt_transform.axcodes
        data = {k: img.clone() for k in ornt.keys}
        res = ornt(data)
        for k in ornt.keys:
            _im = res[k]
            np.testing.assert_allclose(_im.shape, expected_shape)
            if track_meta:
                self.assertIsInstance(_im, MetaTensor)
                code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels)
                self.assertEqual("".join(code), expected_code)
            else:
                self.assertIsInstance(_im, torch.Tensor)
                self.assertNotIsInstance(_im, MetaTensor)
    finally:
        # set_track_meta is process-wide. Restore MONAI's default (True) so a
        # track_meta=False case cannot leak into tests that run later.
        set_track_meta(True)
def get_xforms_with_synthesis(mode="synthesis", keys=("image", "label"), keys2=("image", "label", "synthetic_lesion")):
    """Return a composed transform, optionally with lesion-synthesis augmentation.

    Base steps (always run):
    - load ``keys``, add a channel dimension, reorient to LPS;
    - resample to 1.25 x 1.25 x 5.0 (bilinear image / nearest label);
    - map the image from [-1000, 500] to [0, 1] (clipped);
    - copy the two keys once into ``image_1`` / ``label_1``.

    When ``mode == "synthesis"``:
    - pad in-plane to at least 192 x 192, then take three pos/neg-label
      crops of 192 x 192 x 16;
    - run the project ``TransCustom`` synthesis step;
    - apply random affine, Gaussian noise and flips;
    - finish with ``TransCustom2(0.333)``.

    Finally ``keys`` are cast to float32 (image) and uint8 (label).

    NOTE(review): the affine and the noise act on ``keys2`` (which includes
    ``"synthetic_lesion"``), but the flips act on ``keys`` only, so the
    synthetic lesion is not flipped together with image/label. Confirm this
    is intended.
    """
    base = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
        CopyItemsd(keys, 1, names=["image_1", "label_1"]),
    ]
    augmentation = []
    if mode == "synthesis":
        augmentation = [
            SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),  # ensure at least 192x192
            RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=3),
            TransCustom(
                keys,
                path_synthesis,
                read_cea_aug_slice2,
                pseudo_healthy_with_texture,
                scans_syns,
                decreasing_sequence,
                GEN=15,
                POST_PROCESS=True,
                mask_outer_ring=True,
                new_value=0.5,
            ),
            RandAffined(
                keys2,
                prob=0.15,
                rotate_range=(0.05, 0.05, None),  # one entry per spatial dimension
                scale_range=(0.1, 0.1, None),
                mode=("bilinear", "nearest", "bilinear"),
                as_tensor_output=False,
            ),
            RandGaussianNoised((keys2[0], keys2[2]), prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
            TransCustom2(0.333),
        ]
    cast = CastToTyped(keys, dtype=(np.float32, np.uint8))
    return monai.transforms.Compose(base + augmentation + [cast])
def test_orntd_3d(self):
    """PLI reorientation of 3-D volumes for every array backend in TEST_NDARRAYS."""
    expected_shape = (2, 2, 1, 3)
    expected_code = ("P", "L", "I")
    for to_array in TEST_NDARRAYS:
        sample = {
            "seg": to_array(np.ones((2, 1, 2, 3))),
            "img": to_array(np.ones((2, 1, 2, 3))),
            PostFix.meta("seg"): {"affine": np.eye(4)},
            PostFix.meta("img"): {"affine": np.eye(4)},
        }
        transform = Orientationd(keys=("img", "seg"), axcodes="PLI")
        out = transform(sample)
        for key in ("img", "seg"):
            np.testing.assert_allclose(out[key].shape, expected_shape)
        for key in ("seg", "img"):
            axcodes = nib.aff2axcodes(out[PostFix.meta(key)]["affine"], transform.ornt_transform.labels)
            self.assertEqual(axcodes, expected_code)
def test_load_spacingd_non_diag_ornt(self):
    """Rotate FILES[1]'s meta affine, resample non-diagonally, reorient to LPI, check the affine."""
    sample = LoadNiftid(keys="image")({"image": FILES[1]})
    sample = AddChanneld(keys="image")(sample)
    meta = sample["image_meta"]
    rotation = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]])
    # Store the rotated affine as both the original and the current affine.
    meta["original_affine"] = meta["affine"] = rotation @ meta["affine"]
    resampled = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, mode="border")(sample)
    result = Orientationd(keys="image", axcodes="LPI")(resampled)
    expected = np.array([
        [-3.0, 0.0, 0.0, 56.4005909],
        [0.0, -2.0, 0.0, 52.02241516],
        [0.0, 0.0, -1.0, 35.29789734],
        [0.0, 0.0, 0.0, 1.0],
    ])
    np.testing.assert_allclose(result["image_meta"]["affine"], expected)
def test_load_spacingd_non_diag_ornt(self):
    """Rotate FILES[1]'s MetaTensor affine, then Spacingd + Orientationd(LPI), check the affine."""
    sample = self.load_image(FILES[1])
    image = sample["image"]
    rotation = torch.tensor(
        [[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]], dtype=torch.float64
    )
    # Store the rotated affine as both ``original_affine`` and the current affine.
    image.meta["original_affine"] = image.affine = rotation @ image.affine
    pipeline = Compose([
        Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border"),
        Orientationd(keys="image", axcodes="LPI"),
    ])
    out = pipeline(sample)
    expected = np.array([
        [-3.0, 0.0, 0.0, 56.4005909],
        [0.0, -2.0, 0.0, 52.02241516],
        [0.0, 0.0, -1.0, 35.29789734],
        [0.0, 0.0, 0.0, 1.0],
    ])
    np.testing.assert_allclose(out["image"].affine, expected)
def get_xforms_scans_or_synthetic_lesions(mode="scans", keys=("image", "label")):
    """Return a composed transform for scans or synthetic lesions.

    Common steps:
    - load ``keys`` and add a channel dimension;
    - reorient to LPS;
    - resample to 1.25 x 1.25 x 5.0 (bilinear for the first key, nearest
      for the second).

    In ``"synthetic"`` mode the image is also windowed from [-1000, 500] to
    [0, 1] (clipped) and cast to float32. Otherwise the image is cast to
    int16. The label, when present, is cast to uint8.

    Args:
        mode: ``"synthetic"`` enables intensity scaling; any other value
            keeps raw intensities.
        keys: image key, optionally followed by the label key.

    Returns:
        ``monai.transforms.Compose`` of the steps above.
    """
    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
    ]
    if mode == "synthetic":
        xforms.append(
            ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True)
        )
        dtype = (np.float32, np.uint8)
    else:
        dtype = (np.int16, np.uint8)
    # Slice to one dtype per key, mirroring the Spacingd modes above. MONAI
    # rejects a 2-tuple dtype when only the image key is given.
    xforms.append(CastToTyped(keys, dtype=dtype[: len(keys)]))
    return monai.transforms.Compose(xforms)