def train_pre_transforms(self, context: Context):
    """Build the DeepGrow training pre-transform chain.

    Loads and normalizes the image/label pair, simulates an initial user
    click, and embeds the guidance as extra input channels.
    """
    both = ("image", "label")
    transforms: List[Any] = [
        LoadImaged(keys=both),
        AddChanneld(keys=both),
        SpatialCropForegroundd(keys=both, source_key="label", spatial_size=self.roi_size),
        Resized(keys=both, spatial_size=self.model_size, mode=("area", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=208.0, divisor=388.0),  # type: ignore
    ]
    # Valid slice ids are only meaningful when seeding clicks in 3D volumes.
    if self.dimension == 3:
        transforms.append(FindAllValidSlicesd(label="label", sids="sids"))
    transforms.append(AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"))
    transforms.append(AddGuidanceSignald(image="image", guidance="guidance"))
    transforms.append(EnsureTyped(keys=both, device=context.device))
    transforms.append(SelectItemsd(keys=("image", "label", "guidance")))
    return transforms
def pre_transforms(self, data=None):
    """Inference pre-transforms; DeepEdit mode adds click-guidance channels.

    Args:
        data: optional request dict; only ``data["device"]`` is consulted.
    """
    transforms = [
        LoadImaged(keys="image", reader="ITKReader"),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys="image", axcodes="RAS"),
        ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    ]
    if self.type == InferType.DEEPEDIT:
        # Interactive mode: turn clicks into guidance, resize both, then
        # append the guidance signal channels to the image.
        transforms += [
            AddGuidanceFromPointsCustomd(ref_image="image", guidance="guidance", label_names=self.labels),
            Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
            ResizeGuidanceMultipleLabelCustomd(guidance="guidance", ref_image="image"),
            AddGuidanceSignalCustomd(keys="image", guidance="guidance", number_intensity_ch=self.number_intensity_ch),
        ]
    else:
        # Automatic mode: keep the channel count consistent without clicks.
        transforms += [
            Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
            DiscardAddGuidanced(keys="image", label_names=self.labels, number_intensity_ch=self.number_intensity_ch),
        ]
    target_device = data.get("device") if data else None
    transforms.append(EnsureTyped(keys="image", device=target_device))
    return transforms
def train_pre_transforms(self, context: Context):
    """Training pre-transforms for 2D RGB pathology patches with click guidance."""
    pipeline = [
        LoadImaged(keys=("image", "label"), dtype=np.uint8),
        FilterImaged(keys="image", min_size=5),
        AsChannelFirstd(keys="image"),
        AddChanneld(keys="label"),
        # ColorJitter is a torchvision op, so hop to tensor and back to numpy.
        ToTensord(keys="image"),
        TorchVisiond(keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04),
        ToNumpyd(keys="image"),
        RandRotate90d(keys=("image", "label"), prob=0.5, spatial_axes=(0, 1)),
        # map 8-bit RGB values into [-1, 1]
        ScaleIntensityRangeD(keys="image", b_min=-1.0, b_max=1.0, a_min=0.0, a_max=255.0),
        AddInitialSeedPointExd(label="label", guidance="guidance"),
        AddGuidanceSignald(image="image", guidance="guidance", number_intensity_ch=3),
        EnsureTyped(keys=("image", "label")),
    ]
    return pipeline
def test_dict(self):
    """EnsureTyped must recurse into nested dicts, converting arrays while
    leaving strings and ``None`` untouched."""
    nested = {
        "img": np.array([1.0, 2.0], dtype=np.float32),
        "meta": {"dims": 3, "size": np.array([1, 2, 3]), "path": "temp/test"},
        "extra": None,
    }
    for target in ("tensor", "numpy"):
        expected_cls = torch.Tensor if target == "tensor" else np.ndarray
        converted = EnsureTyped(keys="data", data_type=target, device="cpu")({"data": nested})["data"]
        self.assertTrue(isinstance(converted, dict))
        # arrays at any depth are converted to the requested type
        self.assertTrue(isinstance(converted["img"], expected_cls))
        torch.testing.assert_allclose(converted["img"], torch.as_tensor([1.0, 2.0]))
        self.assertTrue(isinstance(converted["meta"]["size"], expected_cls))
        torch.testing.assert_allclose(converted["meta"]["size"], torch.as_tensor([1, 2, 3]))
        # non-array leaves pass through unchanged
        self.assertEqual(converted["meta"]["path"], "temp/test")
        self.assertEqual(converted["extra"], None)
def test_string(self):
    """Strings — plain or inside numpy arrays — pass through EnsureTyped unchanged."""
    for target in ("tensor", "numpy"):
        # plain python string: no conversion expected
        out = EnsureTyped(keys="data", data_type=target)({"data": "test_string"})["data"]
        self.assertTrue(isinstance(out, str))
        self.assertEqual(out, "test_string")
        # numpy string array: stays a numpy array regardless of target type
        out = EnsureTyped(keys="data", data_type=target)({"data": np.array(["test_string"])})["data"]
        self.assertTrue(isinstance(out, np.ndarray))
        self.assertEqual(out[0], "test_string")
def pre_transforms(self, data=None) -> Sequence[Callable]:
    """DeepGrow inference pre-transforms: normalize spacing, crop/resize around
    the guidance points, then embed the guidance signal channels."""
    dims = self.dimension
    target_device = data.get("device") if data else None
    pipeline = [
        LoadImaged(keys="image"),
        AsChannelFirstd(keys="image"),
        Spacingd(keys="image", pixdim=[1.0] * dims, mode="bilinear"),
        AddGuidanceFromPointsd(ref_image="image", guidance="guidance", dimensions=dims),
    ]
    # A 2D model operates on the single slice indicated by the guidance point.
    if dims == 2:
        pipeline.append(Fetch2DSliced(keys="image", guidance="guidance"))
    pipeline += [
        AddChanneld(keys="image"),
        SpatialCropGuidanced(keys="image", guidance="guidance", spatial_size=self.spatial_size),
        Resized(keys="image", spatial_size=self.model_size, mode="area"),
        ResizeGuidanced(guidance="guidance", ref_image="image"),
        NormalizeIntensityd(keys="image", subtrahend=208, divisor=388),  # type: ignore
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys="image", device=target_device),
    ]
    return pipeline
def train_pre_transforms(self, context: Context):
    """Standard segmentation training pipeline: resample, window, crop, augment."""
    both = ("image", "label")
    return [
        LoadImaged(keys=both),
        AddChanneld(keys=both),
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        # CT intensity window scaled into [0, 1]
        ScaleIntensityRanged(keys="image", b_min=0.0, b_max=1.0, a_min=-57, a_max=164, clip=True),
        CropForegroundd(keys=both, source_key="image"),
        EnsureTyped(keys=both, device=context.device),
        # 4 pos/neg-balanced 96^3 patches per volume
        RandCropByPosNegLabeld(
            keys=both,
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        SelectItemsd(keys=both),
    ]
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Post-transforms: softmax + argmax label map, restored to input geometry."""
    target_device = data.get("device") if data else None
    pipeline = [
        EnsureTyped(keys="pred", device=target_device),
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys="pred", argmax=True),
        SqueezeDimd(keys="pred", dim=0),
        ToNumpyd(keys="pred"),
        # undo spatial pre-processing so the prediction matches the input image
        Restored(keys="pred", ref_image="image"),
    ]
    return pipeline
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Binary post-processing: sigmoid, 0.5 threshold, restore to image space."""
    target_device = data.get("device") if data else None
    pipeline = [
        EnsureTyped(keys="pred", device=target_device),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
        ToNumpyd(keys="pred"),
        # nearest-neighbour restore keeps the mask binary
        RestoreLabeld(keys="pred", ref_image="image", mode="nearest"),
        AsChannelLastd(keys="pred"),
    ]
    return pipeline
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Post-processing that also extracts a bounding box from the prediction."""
    target_device = data.get("device") if data else None
    pipeline = [
        EnsureTyped(keys="pred", device=target_device),
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys="pred", argmax=True),
        ToNumpyd(keys="pred"),
        Restored(keys="pred", ref_image="image"),
        # derive the bounding box of the predicted region for the client
        BoundingBoxd(keys="pred", result="result", bbox="bbox"),
    ]
    return pipeline
def test_list_tuple(self):
    """EnsureTyped converts the elements of lists/tuples while preserving the
    container type itself."""
    for target in ("tensor", "numpy"):
        expected_cls = torch.Tensor if target == "tensor" else np.ndarray
        # nested python lists
        out = EnsureTyped(keys="data", data_type=target)({"data": [[1, 2], [3, 4]]})["data"]
        self.assertTrue(isinstance(out, list))
        self.assertTrue(isinstance(out[0][1], expected_cls))
        torch.testing.assert_allclose(out[1][0], torch.as_tensor(3))
        # tuple of numpy arrays
        out = EnsureTyped(keys="data", data_type=target)({"data": (np.array([1, 2]), np.array([3, 4]))})["data"]
        self.assertTrue(isinstance(out, tuple))
        self.assertTrue(isinstance(out[0], expected_cls))
        torch.testing.assert_allclose(out[1], torch.as_tensor([3, 4]))
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Post-processing for nuclei segmentation down to polygon contours."""
    n = len(self.labels)
    target_device = data.get("device") if data else None
    return [
        EnsureTyped(keys="pred", device=target_device),
        # softmax/argmax for multi-label, sigmoid/threshold for single-label
        Activationsd(keys="pred", softmax=n > 1, sigmoid=n == 1),
        AsDiscreted(keys="pred", argmax=n > 1, threshold=0.5 if n == 1 else None),
        SqueezeDimd(keys="pred", dim=0),
        ToNumpyd(keys=("image", "pred")),
        PostFilterLabeld(keys="pred", image="image"),
        FindContoursd(keys="pred", labels=self.labels),
    ]
def pre_transforms(self, data=None):
    """Load an RGB patch and encode the user clicks as extra input channels."""
    target_device = data.get("device") if data else None
    transforms = [
        LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8, padding=False),
        AsChannelFirstd(keys="image"),
        AddClickSignalsd(image="image"),
        EnsureTyped(keys="image", device=target_device),
    ]
    return transforms
def train_post_transforms(self, context: Context):
    """Post-transforms for training metrics: activation + one-hot discretization."""
    n = len(self._labels)
    return [
        EnsureTyped(keys="pred", device=context.device),
        Activationsd(keys="pred", softmax=n > 1, sigmoid=n == 1),
        # one-hot (labels + background channel) for both prediction and label
        AsDiscreted(
            keys=("pred", "label"),
            argmax=(True, False),
            to_onehot=(n + 1, n + 1),
        ),
    ]
def pre_transforms(self, data=None) -> Sequence[Callable]:
    """Inference pre-processing: resample to target spacing and window CT intensity."""
    transforms = [
        LoadImaged(keys="image", reader="ITKReader"),
        EnsureChannelFirstd(keys="image"),
        Spacingd(keys="image", pixdim=self.target_spacing),
        # CT window [-175, 250] HU scaled into [0, 1]
        ScaleIntensityRanged(keys="image", b_min=0.0, b_max=1.0, a_min=-175, a_max=250, clip=True),
        EnsureTyped(keys="image"),
    ]
    return transforms
def pre_transforms(self, data=None) -> Sequence[Callable]:
    """Inference pre-processing: isotropic resampling plus intensity windowing."""
    transforms = [
        LoadImaged(keys="image"),
        AddChanneld(keys="image"),
        Spacingd(keys="image", pixdim=[1.0, 1.0, 1.0]),
        # CT window [-57, 164] HU scaled into [0, 1]
        ScaleIntensityRanged(keys="image", b_min=0.0, b_max=1.0, a_min=-57, a_max=164, clip=True),
        EnsureTyped(keys="image"),
    ]
    return transforms
def get_click_transforms():
    """Click-simulation transforms applied between DeepGrow training iterations."""
    transforms = [
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=("image", "label", "pred")),
        # locate regions where the prediction disagrees with the label
        FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
        # sample corrective clicks from the discrepancy map
        AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=("image", "label")),
    ]
    return Compose(transforms)
def val_pre_transforms(self, context: Context):
    """Validation pre-transforms: deterministic resampling and intensity windowing."""
    both = ("image", "label")
    return [
        LoadImaged(keys=both, reader="ITKReader"),
        EnsureChannelFirstd(keys=both),
        Spacingd(keys=both, pixdim=self.target_spacing, mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys="image", b_min=0.0, b_max=1.0, a_min=-175, a_max=250, clip=True),
        EnsureTyped(keys=both),
        SelectItemsd(keys=both),
    ]
def get_xforms(mode="train", keys=("image", "label")):
    """Returns a composed transform for train/val/infer.

    Args:
        mode: one of "train", "val", "infer". "train" adds augmentation and
            random cropping; "infer" expects image-only ``keys``.
        keys: dict keys to transform; the first entry is the image, the
            optional second the label.

    Returns:
        a ``monai.transforms.Compose`` of the selected transforms.
    """
    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]
    if mode == "train":
        xforms.extend([
            SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),  # ensure at least 192x192
            RandAffined(
                keys,
                prob=0.15,
                rotate_range=(0.05, 0.05, None),  # 3 parameters control the transform on 3 dimensions
                scale_range=(0.1, 0.1, None),
                mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=3),
            RandGaussianNoised(keys[0], prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
        ])
    # cast image to float32 and (when present) label to uint8; "infer" has no label.
    # replaces the original's three redundant assignments with one expression.
    dtype = (np.float32,) if mode == "infer" else (np.float32, np.uint8)
    xforms.extend([CastToTyped(keys, dtype=dtype), EnsureTyped(keys)])
    return monai.transforms.Compose(xforms)
def pre_transforms(self, data=None):
    """Pathology inference pre-transforms with click-guidance channels."""
    target_device = data.get("device") if data else None
    transforms = [
        LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8),
        FilterImaged(keys="image"),
        AsChannelFirstd(keys="image"),
        # map 8-bit RGB values into [-1, 1]
        ScaleIntensityRangeD(keys="image", b_min=-1.0, b_max=1.0, a_min=0.0, a_max=255.0),
        AddClickGuidanced(image="image", guidance="guidance"),
        AddGuidanceSignald(image="image", guidance="guidance", number_intensity_ch=3),
        EnsureTyped(keys="image", device=target_device),
    ]
    return transforms
def val_pre_transforms(self, context: Context):
    """Validation pipeline: same deterministic steps as training, no augmentation."""
    both = ("image", "label")
    return [
        LoadImaged(keys=both),
        AddChanneld(keys=both),
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys="image", b_min=0.0, b_max=1.0, a_min=-57, a_max=164, clip=True),
        CropForegroundd(keys=both, source_key="image"),
        EnsureTyped(keys=both, device=context.device),
    ]
def test_single_input(self):
    """Scalars and 0-d arrays/tensors should convert to 0-d values of the
    requested type; booleans keep their truth value."""
    inputs = [5, 5.0, False, np.asarray(5), torch.tensor(5)]
    if torch.cuda.is_available():
        inputs.append(inputs[-1].cuda())
    for value in inputs:
        for target in ("tensor", "numpy"):
            expected_cls = torch.Tensor if target == "tensor" else np.ndarray
            out = EnsureTyped(keys="data", data_type=target)({"data": value})["data"]
            self.assertTrue(isinstance(out, expected_cls))
            if isinstance(value, bool):
                self.assertFalse(out)
            else:
                assert_allclose(out, value, type_test=False)
            # conversion must not add dimensions
            self.assertEqual(out.ndim, 0)
def get_pre_transforms(roi_size, model_size, dimensions):
    """Compose DeepGrow training pre-transforms for the given crop/model sizes.

    Args:
        roi_size: spatial size of the label-centered foreground crop.
        model_size: spatial size the network expects.
        dimensions: 2 or 3; 3D adds valid-slice discovery for click seeding.
    """
    both = ("image", "label")
    transforms = [
        LoadImaged(keys=both),
        AddChanneld(keys=both),
        SpatialCropForegroundd(keys=both, source_key="label", spatial_size=roi_size),
        Resized(keys=both, spatial_size=model_size, mode=("area", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=208.0, divisor=388.0),
    ]
    # slice ids are only needed to seed clicks in 3D volumes
    if dimensions == 3:
        transforms.append(FindAllValidSlicesd(label="label", sids="sids"))
    transforms += [
        AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=both),
    ]
    return Compose(transforms)
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Segmentation post-processing with optional largest-connected-component filtering.

    ``data["largest_cc"]`` toggles the connected-component cleanup step.
    """
    keep_largest = False if not data else data.get("largest_cc", False)
    n = len(self.labels)
    applied_labels = list(self.labels.values()) if isinstance(self.labels, dict) else self.labels
    transforms = [
        EnsureTyped(keys="pred", device=data.get("device") if data else None),
        Activationsd(keys="pred", softmax=n > 1, sigmoid=n == 1),
        AsDiscreted(keys="pred", argmax=n > 1, threshold=0.5 if n == 1 else None),
    ]
    if keep_largest:
        transforms.append(KeepLargestConnectedComponentd(keys="pred", applied_labels=applied_labels))
    transforms += [
        ToNumpyd(keys="pred"),
        Restored(keys="pred", ref_image="image"),
    ]
    return transforms
def train_pre_transforms(self, context: Context):
    """Training pipeline: label normalization, resampling, cropping and augmentation."""
    both = ("image", "label")
    return [
        LoadImaged(keys=both, reader="ITKReader"),
        # remap label ids so datasets with missing labels are handled consistently
        NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
        EnsureChannelFirstd(keys=both),
        Spacingd(keys=both, pixdim=self.target_spacing, mode=("bilinear", "nearest")),
        CropForegroundd(keys=both, source_key="image"),
        SpatialPadd(keys=both, spatial_size=self.spatial_size),
        ScaleIntensityRanged(keys="image", b_min=0.0, b_max=1.0, a_min=-175, a_max=250, clip=True),
        RandCropByPosNegLabeld(
            keys=both,
            label_key="label",
            spatial_size=self.spatial_size,
            pos=1,
            neg=1,
            num_samples=self.num_samples,
            image_key="image",
            image_threshold=0,
        ),
        EnsureTyped(keys=both, device=context.device),
        # light flip/rotate/intensity augmentation
        RandFlipd(keys=both, spatial_axis=[0], prob=0.10),
        RandFlipd(keys=both, spatial_axis=[1], prob=0.10),
        RandFlipd(keys=both, spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=both, prob=0.10, max_k=3),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        SelectItemsd(keys=both),
    ]
def test_array_input(self):
    """2x2 arrays/tensors should round-trip with the expected type, dtype, shape."""
    inputs = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])]
    if torch.cuda.is_available():
        inputs.append(inputs[-1].cuda())
    for value in inputs:
        # the upper-case "NUMPY" also exercises case-insensitive data_type handling
        for target in ("tensor", "NUMPY"):
            converter = EnsureTyped(
                keys="data",
                data_type=target,
                dtype=np.float32 if target == "NUMPY" else None,
                device="cpu",
            )
            out = converter({"data": value})["data"]
            if target == "NUMPY":
                self.assertTrue(out.dtype == np.float32)
            self.assertTrue(isinstance(out, torch.Tensor if target == "tensor" else np.ndarray))
            assert_allclose(out, value, type_test=False)
            self.assertTupleEqual(out.shape, (2, 2))
def test_invert(self):
    """End-to-end check of ``Invertd``: apply a chain of deterministic and random
    spatial transforms, invert the results for several keys (including a plain
    dict copied from the image meta), and verify shapes, interpolation behavior
    and meta-data round-tripping."""
    set_determinism(seed=0)
    # synthetic image/label volumes written to temporary NIfTI files
    im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
    transform = Compose([
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        # test EnsureTensor for complicated dict data and invert it
        CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
        # test to support Tensor, Numpy array and dictionary when inverting
        EnsureTyped(keys=["image", "test_dict"]),
        ToTensord("image"),
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
        CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
    ])
    # 12 copies of the same image/label pair, batched 5 at a time below
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted", "label_inverted", "test_dict"],
        transform=transform,
        orig_keys=["label", "label", "test_dict"],
        meta_keys=[PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None],
        orig_meta_keys=[PostFix.meta("label"), PostFix.meta("label"), None],
        nearest_interp=True,
        to_tensor=[True, False, False],
        device="cpu",
    )

    inverter_1 = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted1", "label_inverted1"],
        transform=transform,
        orig_keys=["image", "image"],
        meta_keys=[PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1")],
        orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")],
        nearest_interp=[True, False],
        to_tensor=[True, True],
        device="cpu",
    )

    # exact key set every inverted item is expected to carry
    expected_keys = [
        "image",
        "image_inverted",
        "image_inverted1",
        PostFix.meta("image_inverted1"),
        PostFix.meta("image_inverted"),
        PostFix.meta("image"),
        "image_transforms",
        "label",
        "label_inverted",
        "label_inverted1",
        PostFix.meta("label_inverted1"),
        PostFix.meta("label_inverted"),
        PostFix.meta("label"),
        "label_transforms",
        "test_dict",
        "test_dict_transforms",
    ]
    # execute 1 epoch
    for d in loader:
        d = decollate_batch(d)
        for item in d:
            item = inverter(item)
            item = inverter_1(item)

            self.assertListEqual(sorted(item), expected_keys)
            self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
            self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
            # check the nearest interpolation mode
            i = item["image_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            i = item["label_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            # test inverted test_dict
            self.assertTrue(isinstance(item["test_dict"]["affine"], np.ndarray))
            self.assertTrue(isinstance(item["test_dict"]["filename_or_obj"], str))

            # check the case that different items use different interpolation mode to invert transforms
            d = item["image_inverted1"]
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

            d = item["label_inverted1"]
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

            # check labels match
            reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
            original = LoadImaged(KEYS)(data[-1])["label"]
            n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
            reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
            original_name = data[-1]["label"]
            self.assertEqual(reverted_name, original_name)
            print("invert diff", reverted.size - n_good)
            # 25300: 2 workers (cpu, non-macos)
            # 1812: 0 workers (gpu or macos)
            # 1821: windows torch 1.10.0
            # NOTE(review): the accepted value 34007 below does not match the
            # "25300: 2 workers" note above — confirm which figure is current.
            self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}")

    set_determinism(seed=None)
def evaluate(args):
    """Horovod-distributed evaluation of a 3D UNet with sliding-window inference.

    Rank 0 generates synthetic NIfTI data (if missing) and loads the trained
    weights; parameters are then broadcast to all workers, each of which
    evaluates its shard of the validation set and aggregates a mean Dice.

    Args:
        args: parsed CLI namespace; ``args.dir`` is the data directory.
    """
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # pair images with their masks by sorted filename
    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    # create an evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler (each rank sees its own shard)
    val_sampler = DistributedSampler(val_ds, shuffle=False, num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(mp, "_supports_context") and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters (so every rank gets rank 0's weights)
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        # aggregate the mean Dice across all iterations, then reset the metric state
        metric = dice_metric.aggregate().item()
        dice_metric.reset()

    if hvd.rank() == 0:
        print("evaluation metric:", metric)
def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_samples):
    """Build the composed transform chain for one decathlon task.

    Args:
        mode: "train", "validation" or "test"; "test" loads only the image.
        task_id: key into the module-level ``clip_values`` / ``spacing`` /
            ``normalize_values`` / ``patch_size`` tables.
        pos_sample_num: ``pos`` sampling weight for ``RandCropByPosNegLabeld``.
        neg_sample_num: ``neg`` sampling weight for ``RandCropByPosNegLabeld``.
        num_samples: number of random crops per volume (training only).

    Returns:
        A ``Compose`` of loading, preprocessing and mode-specific transforms.
    """
    # 1. loading: the test set carries no label key
    if mode != "test":
        keys = ["image", "label"]
    else:
        keys = ["image"]
    load_transforms = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
    ]

    # 2. sampling
    sample_transforms = [
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]

    # 3. spatial transforms
    if mode == "train":
        other_transforms = [
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size[task_id]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size[task_id],
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    elif mode == "validation":
        other_transforms = [
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    else:
        other_transforms = [
            # fix: the original passed ``(np.float32)``, which is NOT a tuple —
            # a real one-element tuple matches the branches above explicitly
            # instead of relying on MONAI broadcasting a bare dtype.
            CastToTyped(keys=["image"], dtype=(np.float32,)),
            EnsureTyped(keys=["image"]),
        ]

    all_transforms = load_transforms + sample_transforms + other_transforms
    return Compose(all_transforms)
def train(args):
    """Multi-GPU training of a 3D UNet with DistributedDataParallel + SmartCacheDataset.

    Rank 0 generates synthetic NIfTI data if needed; every rank then trains on
    its own partition of the data with a smart-replacement cache, and rank 0
    saves the final weights.

    Args:
        args: parsed CLI namespace; uses ``args.local_rank`` and ``args.dir``.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    elif not os.path.exists(args.dir):
        # create 40 random image, mask pairs for training
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    # pair images with their masks by sorted filename
    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # partition dataset based on current rank number, every rank trains with its own data
    # it can avoid duplicated caching content in each rank, but will not do global shuffle before every epoch
    data_part = partition_dataset(
        data=train_files,
        num_partitions=dist.get_world_size(),
        shuffle=True,
        even_divisible=True,
    )[dist.get_rank()]

    train_ds = SmartCacheDataset(
        data=data_part,
        transform=train_transforms,
        replace_rate=0.2,
        cache_num=15,  # we suppose to use 2 ranks in this example, every rank has 20 training images
        num_init_workers=2,
        num_replace_workers=2,
    )
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=True)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # start a typical PyTorch training
    epoch_loss_values = list()
    # start the replacement thread of SmartCache
    train_ds.start()

    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        # replace 20% of cache content for next epoch
        train_ds.update_cache()
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # stop replacement thread of SmartCache
    train_ds.shutdown()

    print(f"train completed, epoch losses: {epoch_loss_values}")
    if dist.get_rank() == 0:
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")

    dist.destroy_process_group()