def test_value_shape(self, input_param, test_input, output, expected_shape):
    """Check the values and shapes produced by ``AsDiscreted``.

    ``pred`` is always verified; ``label`` is verified only when the
    transform result contains that key.

    Args:
        input_param: keyword arguments used to construct ``AsDiscreted``.
        test_input: dictionary passed to the transform.
        output: dictionary of expected values, keyed like the result.
        expected_shape: shape every verified entry must have.
    """
    transform = AsDiscreted(**input_param)
    result = transform(test_input)

    keys_to_check = ["pred"]
    if "label" in result:
        keys_to_check.append("label")

    for key in keys_to_check:
        assert_allclose(result[key], output[key], rtol=1e-3)
        self.assertTupleEqual(result[key].shape, expected_shape)
def post_transforms(self, data=None):
    """Build the post-processing chain applied to the ``pred`` entry.

    Args:
        data: optional request data; not read by this implementation.

    Returns:
        The list of transforms, in the order they are applied.
    """
    activation = Activationsd(keys="pred", sigmoid=True)
    discretize = AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5)
    to_numpy = ToNumpyd(keys="pred")
    restore = RestoreLabeld(keys="pred", ref_image="image", mode="nearest")
    channel_last = AsChannelLastd(keys="pred")
    return [activation, discretize, to_numpy, restore, channel_last]
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Build the softmax/argmax post-processing chain for ``pred``.

    Args:
        data: optional request dictionary; its ``"device"`` entry (if any)
            is forwarded to ``EnsureTyped``.

    Returns:
        The list of transforms, in the order they are applied.
    """
    # Fall back to no explicit device when no request data is supplied.
    device = data.get("device") if data else None
    steps = [
        EnsureTyped(keys="pred", device=device),
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys="pred", argmax=True),
        SqueezeDimd(keys="pred", dim=0),
        ToNumpyd(keys="pred"),
        Restored(keys="pred", ref_image="image"),
    ]
    return steps
def train_post_transforms(self, context: Context):
    """Build the training-time post transforms for ``pred`` and ``label``.

    Softmax is used when more than one label is configured, sigmoid when
    exactly one is. Both entries are one-hot encoded with
    ``len(self.labels) + 1`` classes; only ``pred`` goes through argmax.

    Args:
        context: training context; not read by this implementation.

    Returns:
        The list of transforms, in the order they are applied.
    """
    label_count = len(self.labels)
    onehot_classes = label_count + 1

    activation = Activationsd(
        keys="pred",
        softmax=label_count > 1,
        sigmoid=label_count == 1,
    )
    discretize = AsDiscreted(
        keys=("pred", "label"),
        argmax=(True, False),
        to_onehot=(onehot_classes, onehot_classes),
    )
    return [activation, discretize]
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Build the post-processing chain that ends with a bounding-box step.

    Args:
        data: optional request dictionary; its ``"device"`` entry (if any)
            is forwarded to ``EnsureTyped``.

    Returns:
        The list of transforms, in the order they are applied.
    """
    if data:
        device = data.get("device")
    else:
        device = None

    pipeline = []
    pipeline.append(EnsureTyped(keys="pred", device=device))
    pipeline.append(Activationsd(keys="pred", softmax=True))
    pipeline.append(AsDiscreted(keys="pred", argmax=True))
    pipeline.append(ToNumpyd(keys="pred"))
    pipeline.append(Restored(keys="pred", ref_image="image"))
    pipeline.append(BoundingBoxd(keys="pred", result="result", bbox="bbox"))
    return pipeline
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Build the sigmoid/threshold post-processing chain for ``pred``.

    Args:
        data: optional request dictionary; its ``"device"`` entry (if any)
            is forwarded to ``EnsureTyped``.

    Returns:
        The list of transforms, in the order they are applied.
    """
    target_device = data.get("device") if data else None

    ensure_type = EnsureTyped(keys="pred", device=target_device)
    activation = Activationsd(keys="pred", sigmoid=True)
    discretize = AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5)
    to_numpy = ToNumpyd(keys="pred")
    restore = RestoreLabeld(keys="pred", ref_image="image", mode="nearest")
    channel_last = AsChannelLastd(keys="pred")

    return [ensure_type, activation, discretize, to_numpy, restore, channel_last]
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Build the post-processing chain that ends with contour extraction.

    With more than one configured label the prediction goes through
    softmax + argmax; with exactly one label it goes through sigmoid and a
    0.5 threshold.

    Args:
        data: optional request dictionary; its ``"device"`` entry (if any)
            is forwarded to ``EnsureTyped``.

    Returns:
        The list of transforms, in the order they are applied.
    """
    device = data.get("device") if data else None

    label_count = len(self.labels)
    multi_class = label_count > 1
    single_class = label_count == 1
    threshold = 0.5 if single_class else None

    return [
        EnsureTyped(keys="pred", device=device),
        Activationsd(keys="pred", softmax=multi_class, sigmoid=single_class),
        AsDiscreted(keys="pred", argmax=multi_class, threshold=threshold),
        SqueezeDimd(keys="pred", dim=0),
        ToNumpyd(keys=("image", "pred")),
        PostFilterLabeld(keys="pred", image="image"),
        FindContoursd(keys="pred", labels=self.labels),
    ]
def test_value_shape(self, input_param, test_input, output, expected_shape):
    """Check the ``*_discrete`` entries produced by ``AsDiscreted``.

    ``pred_discrete`` is always verified; ``label_discrete`` is verified
    only when the transform result contains that key.

    Args:
        input_param: keyword arguments used to construct ``AsDiscreted``.
        test_input: dictionary passed to the transform.
        output: dictionary of expected values, keyed like the result.
        expected_shape: shape every verified entry must have.
    """
    transform = AsDiscreted(**input_param)
    result = transform(test_input)

    keys_to_check = ["pred_discrete"]
    if "label_discrete" in result:
        keys_to_check.append("label_discrete")

    for key in keys_to_check:
        torch.testing.assert_allclose(result[key], output[key])
        self.assertTupleEqual(result[key].shape, expected_shape)
def train_post_transforms(self, context: Context):
    """Build the training-time post transforms for a two-class setup.

    Args:
        context: training context; not read by this implementation.

    Returns:
        The list of transforms, in the order they are applied.
    """
    paired_keys = ("pred", "label")

    to_tensor = ToTensord(keys=paired_keys)
    activation = Activationsd(keys="pred", softmax=True)
    # Only ``pred`` goes through argmax; both entries are one-hot encoded.
    discretize = AsDiscreted(
        keys=paired_keys,
        argmax=(True, False),
        to_onehot=True,
        n_classes=2,
    )
    return [to_tensor, activation, discretize]
def train_post_transforms(self, context: Context):
    """Build the training-time post transforms, ending with a prediction split.

    The number of one-hot classes is ``len(self._labels)``.

    Args:
        context: training context; not read by this implementation.

    Returns:
        The list of transforms, in the order they are applied.
    """
    class_count = len(self._labels)
    steps = [
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(
            keys=("pred", "label"),
            argmax=(True, False),
            to_onehot=(True, True),
            n_classes=class_count,
        ),
        SplitPredsLabeld(keys="pred"),
    ]
    return steps
def test_compute(self):
    """Run an evaluator whose handlers decollate and post-process the output.

    The engine first applies its own ``postprocessing`` (adds 1.0 to
    ``pred``) on the batch, then the handlers decollate and apply the
    composed dictionary transforms. The final ``pred`` values and the
    copied filename are compared with the expected ones.
    """
    data = [
        {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]},
        {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]},
    ]

    decollate_handler = DecollateBatch(event="MODEL_COMPLETED")
    dict_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            CopyItemsd(keys="filename", times=1, names="filename_bak"),
            AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2),
        ]
    )
    handlers = [decollate_handler, PostProcessing(transform=dict_transforms)]

    def _add_one(x):
        # executed by the engine itself, before the handlers decollate
        return dict(pred=x["pred"] + 1.0)

    # set up engine, PostProcessing handler works together with postprocessing transforms of engine
    engine = SupervisedEvaluator(
        device=torch.device("cpu:0"),
        val_data_loader=data,
        epoch_length=2,
        network=torch.nn.PReLU(),
        # set decollate=False and execute some postprocessing first, then decollate in handlers
        postprocessing=_add_one,
        decollate=False,
        val_handlers=handlers,
    )
    engine.run()

    expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]])
    for out_item, expected_item in zip(engine.state.output, expected):
        torch.testing.assert_allclose(out_item["pred"], expected_item)
        copied_name = out_item.get("filename_bak")
        if copied_name is not None:
            self.assertEqual(copied_name, "test2")
def get_click_transforms(self, context: Context):
    """Build the transforms used to simulate user clicks during training.

    Args:
        context: training context; not read by this implementation.

    Returns:
        The list of transforms, in the order they are applied.
    """
    # Discretize the prediction and convert the involved entries to numpy.
    prediction_steps = [
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(keys="pred", argmax=True),
        ToNumpyd(keys=("image", "label", "pred")),
    ]

    # Transforms for click simulation
    click_steps = [
        FindDiscrepancyRegionsCustomd(keys="label", pred="pred", discrepancy="discrepancy"),
        AddRandomGuidanceCustomd(
            keys="NA",
            guidance="guidance",
            discrepancy="discrepancy",
            probability="probability",
        ),
        AddGuidanceSignalCustomd(
            keys="image",
            guidance="guidance",
            number_intensity_ch=self.number_intensity_ch,
        ),
        # ToTensord(keys=("image", "label")),
    ]

    return prediction_steps + click_steps
def post_transforms(self, data=None) -> Sequence[Callable]:
    """Build the post-processing chain, optionally keeping the largest component.

    With more than one configured label the prediction goes through
    softmax + argmax; with exactly one label it goes through sigmoid and a
    0.5 threshold. ``KeepLargestConnectedComponentd`` is inserted only when
    the request sets ``"largest_cc"`` to a truthy value.

    Args:
        data: optional request dictionary; ``"device"`` and ``"largest_cc"``
            are read from it when present.

    Returns:
        The list of transforms, in the order they are applied.
    """
    largest_cc = data.get("largest_cc", False) if data else False

    # ``self.labels`` may be a dict (its values are used) or a plain sequence.
    if isinstance(self.labels, dict):
        applied_labels = list(self.labels.values())
    else:
        applied_labels = self.labels

    device = data.get("device") if data else None
    label_count = len(self.labels)
    multi_class = label_count > 1
    single_class = label_count == 1

    pipeline = [
        EnsureTyped(keys="pred", device=device),
        Activationsd(keys="pred", softmax=multi_class, sigmoid=single_class),
        AsDiscreted(keys="pred", argmax=multi_class, threshold=0.5 if single_class else None),
    ]
    if largest_cc:
        pipeline.append(KeepLargestConnectedComponentd(keys="pred", applied_labels=applied_labels))
    pipeline += [
        ToNumpyd(keys="pred"),
        Restored(keys="pred", ref_image="image"),
    ]
    return pipeline
import torch from parameterized import parameterized from monai.engines import SupervisedEvaluator from monai.handlers import PostProcessing from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd # test lambda function as `transform` TEST_CASE_1 = [{"transform": lambda x: dict(pred=x["pred"] + 1.0)}, False, torch.tensor([[[[1.9975], [1.9997]]]])] # test composed postprocessing transforms as `transform` TEST_CASE_2 = [ { "transform": Compose( [ CopyItemsd(keys="filename", times=1, names="filename_bak"), AsDiscreted(keys="pred", threshold=0.5, to_onehot=2), ] ), "event": "iteration_completed", }, True, torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]), ] class TestHandlerPostProcessing(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_compute(self, input_params, decollate, expected): data = [ {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]}, {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]},
def main(tempdir):
    """Run sliding-window inference on synthetic volumes and save the results.

    Five synthetic 3D images are written to ``tempdir``, loaded back through
    the preprocessing pipeline and fed to a 3D UNet whose weights are read
    from ``best_metric_model_segmentation3d_dict.pth``. Each prediction is
    passed through the post transforms (sigmoid, inversion of the
    preprocessing, 0.5 threshold, ``SaveImaged`` with ``output_dir="./out"``).

    Args:
        tempdir: directory where the synthetic ``im*.nii.gz`` files are
            written and then globbed from.
    """
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose(
        [
            LoadImaged(keys="img"),
            EnsureChannelFirstd(keys="img"),
            Orientationd(keys="img", axcodes="RAS"),
            Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys="img"),
        ]
    )
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose(
        [
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", sigmoid=True),
            Invertd(
                keys="pred",  # invert the `pred` data field, also support multiple fields
                transform=pre_transforms,
                # get the previously applied pre_transforms information on the `img` data field,
                # then invert `pred` based on this information. we can use same info
                # for multiple fields, also support different orig_keys for different fields
                orig_keys="img",
                meta_keys="pred_meta_dict",  # key field to save inverted meta data, every item maps to `keys`
                # get the meta data from `img_meta_dict` field when inverting,
                # for example, may need the `affine` to invert `Spacingd` transform,
                # multiple fields can use the same meta data to invert
                orig_meta_keys="img_meta_dict",
                # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
                # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
                # otherwise, no need this arg during inverting
                meta_key_postfix="meta_dict",
                # don't change the interpolation mode to "nearest" when inverting transforms
                # to ensure a smooth output, then execute `AsDiscreted` transform
                nearest_interp=False,
                to_tensor=True,  # convert to PyTorch Tensor after inverting
            ),
            AsDiscreted(keys="pred", threshold=0.5),
            SaveImaged(
                keys="pred",
                meta_keys="pred_meta_dict",
                output_dir="./out",
                output_postfix="seg",
                resample=False,
            ),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # Remap the stored tensors onto the selected device: without
    # `map_location`, a checkpoint saved from a GPU cannot be loaded on the
    # CPU fallback chosen above.
    net.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth", map_location=device)
    )

    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(
                inputs=images, roi_size=(96, 96, 96), sw_batch_size=4, predictor=net
            )
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]
def get_post_transforms():
    """Return the composed post-processing for ``pred``.

    The chain applies a sigmoid activation followed by ``AsDiscreted`` with
    ``threshold_values=True`` and ``logit_thresh=0.5``.
    """
    steps = [
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold_values=True, logit_thresh=0.5),
    ]
    return Compose(steps)
def run_inference(input_data, config_info):
    """
    Pipeline to run inference with MONAI dynUNet model.

    The pipeline reads the input filenames, applies the required preprocessing and creates
    the pytorch dataloader; it then performs evaluation on each input file using a trained
    dynUNet model. Each prediction is the average of the network output on the input and on
    three flipped copies of it (flips are undone before averaging).

    It uses the dynUNet model implemented in the MONAI framework
    (https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/nets/dynunet.py)
    which is inspired by the nnU-Net framework (https://arxiv.org/abs/1809.10486)
    Inference is performed in 2D slice-by-slice, all slices are then recombined together
    into the 3D volume.

    Args:
        input_data: str or list of strings, filenames of images to be processed
        config_info: dict, contains the configuration parameters to reload the trained model

    Raises:
        FileNotFoundError: if ``config_info['inference']['model_to_load']`` does not exist.
        Exception: if the configured number of output channels is neither 1 nor greater than 1.
    """
    # ---------------------------------------------------------------
    # Read input and configuration parameters
    # ---------------------------------------------------------------
    val_files = create_data_list_of_dictionaries(input_data)

    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print("*** MONAI config: ")
    print_config()

    # print to log the parameter setups
    print("*** Network inference config: ")
    print(yaml.dump(config_info))

    # inference params
    inference_cfg = config_info['inference']
    nr_out_channels = inference_cfg['nr_out_channels']
    spacing = inference_cfg["spacing"]
    prob_thr = inference_cfg['probability_threshold']
    model_to_load = inference_cfg['model_to_load']
    if not os.path.exists(model_to_load):
        raise FileNotFoundError('Trained model not found')
    # in-plane size plus a single slice along the third axis
    patch_size = inference_cfg["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}".format(
            torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
    print("Using device: {}".format(current_device))

    # ---------------------------------------------------------------
    # Data Preparation
    # ---------------------------------------------------------------
    print("*** Preparing data ... ")
    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - resample to the training resolution in-plane (not along z)
    # - apply whitening
    # - convert to tensor
    preprocessing_steps = [
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        InPlaneSpacingd(
            keys=["image"],
            pixdim=spacing,
            mode="bilinear",
        ),
        NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
        ToTensord(keys=["image"]),
    ]
    val_transforms = Compose(preprocessing_steps)

    # create a validation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=config_info['device']['num_workers'])

    def prepare_batch(batchdata):
        # Return an (image, label) pair; label is None when the batch has none.
        assert isinstance(batchdata, dict), "prepare_batch expects dictionary input data."
        if "label" in batchdata:
            return (batchdata["image"], batchdata["label"])
        return (batchdata["image"], None)

    # ---------------------------------------------------------------
    # Network preparation
    # ---------------------------------------------------------------
    print("*** Preparing network ... ")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        # stop once no axis can be down-sampled any further
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)

    # ---------------------------------------------------------------
    # Set ignite evaluator to perform inference
    # ---------------------------------------------------------------
    print("*** Preparing evaluator ... ")
    if nr_out_channels == 1:
        do_sigmoid, do_softmax = True, False
    elif nr_out_channels > 1:
        do_sigmoid, do_softmax = False, True
    else:
        raise Exception("incompatible number of output channels")
    print("Using sigmoid={} and softmax={} as final activation".format(
        do_sigmoid, do_softmax))

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=do_sigmoid, softmax=do_softmax),
        AsDiscreted(keys="pred",
                    argmax=True,
                    threshold_values=True,
                    logit_thresh=prob_thr),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=1),
    ])

    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_to_load,
                         load_dict={"net": net},
                         map_location=torch.device('cpu')),
        SegmentationSaver(
            output_dir=config_info['output']['out_dir'],
            output_ext='.nii.gz',
            output_postfix=config_info['output']['out_postfix'],
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs = inputs.to(engine.state.device)
            if targets is not None:
                targets = targets.to(engine.state.device)

            flipped_dim2 = torch.flip(inputs, dims=(2, ))
            flipped_dim3 = torch.flip(inputs, dims=(3, ))
            flipped_both = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use flipping as data augmentation at inference; each
                # prediction is flipped back before averaging
                pred_dim2 = torch.flip(self.inferer(flipped_dim2, self.network), dims=(2, ))
                pred_dim3 = torch.flip(self.inferer(flipped_dim3, self.network), dims=(3, ))
                pred_both = torch.flip(self.inferer(flipped_both, self.network), dims=(2, 3))
                return (pred + pred_dim2 + pred_dim3 + pred_both) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        prepare_batch=prepare_batch,
        inferer=SlidingWindowInferer2D(roi_size=patch_size,
                                       sw_batch_size=4,
                                       overlap=0.0),
        post_transform=val_post_transforms,
        val_handlers=val_handlers,
        amp=False,
    )

    # ---------------------------------------------------------------
    # Run inference
    # ---------------------------------------------------------------
    print("*** Running evaluator ... ")
    evaluator.run()
    print("Done!")

    return
def train(args):
    """Train a 3D UNet on synthetic data with ``DistributedDataParallel``.

    The process with ``args.local_rank == 0`` generates 40 synthetic
    image/segmentation pairs in ``args.dir`` when that directory does not
    exist yet. Every process then joins the NCCL process group, builds its
    data loader with a ``DistributedSampler`` and runs a
    ``SupervisedTrainer`` for five epochs.

    Args:
        args: parsed options; ``args.local_rank`` and ``args.dir`` are read.
    """
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 40 random image, mask pairs for training
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            image_nifti = nib.Nifti1Image(im, np.eye(4))
            nib.save(image_nifti, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            seg_nifti = nib.Nifti1Image(seg, np.eye(4))
            nib.save(seg_nifti, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = []
    for img, seg in zip(images, segs):
        train_files.append({"image": img, "label": seg})

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = Dataset(data=train_files, transform=train_transforms)
    # create a training data sampler
    train_sampler = DistributedSampler(train_ds)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=train_sampler,
    )

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])

    train_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    train_handlers = [LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True)]
    # logging and checkpointing are only attached on the rank-0 process
    if dist.get_rank() == 0:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        train_handlers += [
            StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2),
        ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)
        },
        train_handlers=train_handlers,
    )
    trainer.run()
    dist.destroy_process_group()
def evaluate(args):
    """Evaluate a 3D UNet on synthetic data with ``DistributedDataParallel``.

    The process with ``args.local_rank == 0`` generates 16 synthetic
    image/segmentation pairs in ``args.dir`` when that directory does not
    exist yet. Every process then joins the NCCL process group, loads
    ``./runs/checkpoint_epoch=4.pt`` and runs a ``SupervisedEvaluator`` with
    sliding-window inference.

    Args:
        args: parsed options; ``args.local_rank`` and ``args.dir`` are read.
    """
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            image_nifti = nib.Nifti1Image(im, np.eye(4))
            nib.save(image_nifti, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            seg_nifti = nib.Nifti1Image(seg, np.eye(4))
            nib.save(seg_nifti, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = []
    for img, seg in zip(images, segs):
        val_files.append({"image": img, "label": seg})

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create a evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
    )

    # create the UNet on this process' GPU
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    checkpoint_loader = CheckpointLoader(
        load_path="./runs/checkpoint_epoch=4.pt",
        load_dict={"net": net},
        # config mapping to expected GPU device
        map_location={"cuda:0": f"cuda:{args.local_rank}"},
    )
    val_handlers = [checkpoint_loader]
    # logging and saving of segmentations are only attached on the rank-0 process
    if dist.get_rank() == 0:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        val_handlers += [
            StatsHandler(output_transform=lambda x: None),
            SegmentationSaver(
                output_dir="./runs/",
                batch_transform=lambda batch: batch["image_meta_dict"],
                output_transform=lambda output: output["pred"],
            ),
        ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=device,
            )
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)
        },
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    evaluator.run()
    dist.destroy_process_group()
def configure(self):
    """Build the network, the data pipelines and both engines.

    The validation engine is stored in ``self.eval_engine`` and the training
    engine in ``self.train_engine``. The following attributes are read:
    ``device``, ``multi_gpu``, ``data_list_file_path``, ``ckpt_dir``,
    ``amp``, ``learning_rate``, ``val_interval``, ``max_epochs`` and
    ``local_rank``.
    """
    self.set_device()

    network = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(self.device)
    if self.multi_gpu:
        network = DistributedDataParallel(
            module=network,
            device_ids=[self.device],
            find_unused_parameters=False,
        )

    # ---- training data ----
    train_transforms = Compose(
        [
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            Spacingd(keys=("image", "label"), pixdim=[1.0, 1.0, 1.0], mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            RandCropByPosNegLabeld(
                keys=("image", "label"),
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=("image", "label")),
        ]
    )
    train_datalist = load_decathlon_datalist(self.data_list_file_path, True, "training")
    if self.multi_gpu:
        # each process keeps only the partition matching its rank
        partitions = partition_dataset(
            data=train_datalist,
            shuffle=True,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
        )
        train_datalist = partitions[dist.get_rank()]
    train_ds = CacheDataset(
        data=train_datalist,
        transform=train_transforms,
        cache_num=32,
        cache_rate=1.0,
        num_workers=4,
    )
    train_data_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # ---- validation data ----
    val_transforms = Compose(
        [
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            ToTensord(keys=("image", "label")),
        ]
    )
    val_datalist = load_decathlon_datalist(self.data_list_file_path, True, "validation")
    val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
    val_data_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

    # post transform shared by the evaluator and the trainer
    post_transform = Compose(
        [
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ]
    )

    # ---- validation engine ----
    # metric
    key_val_metric = {
        "val_mean_dice": MeanDice(
            include_background=False,
            output_transform=lambda x: (x["pred"], x["label"]),
            device=self.device,
        )
    }
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(
            save_dir=self.ckpt_dir,
            save_dict={"model": network},
            save_key_metric=True,
        ),
        TensorBoardStatsHandler(log_dir=self.ckpt_dir, output_transform=lambda x: None),
    ]
    self.eval_engine = SupervisedEvaluator(
        device=self.device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=SlidingWindowInferer(roi_size=[160, 160, 160], sw_batch_size=4, overlap=0.5),
        post_transform=post_transform,
        key_val_metric=key_val_metric,
        val_handlers=val_handlers,
        amp=self.amp,
    )

    # ---- training engine ----
    optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=self.eval_engine, interval=self.val_interval, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(
            log_dir=self.ckpt_dir,
            tag_name="train_loss",
            output_transform=lambda x: x["loss"],
        ),
    ]
    self.train_engine = SupervisedTrainer(
        device=self.device,
        max_epochs=self.max_epochs,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # reduce the log level of both engines on every process but local rank 0
    if self.local_rank > 0:
        for engine in (self.train_engine, self.eval_engine):
            engine.logger.setLevel(logging.WARNING)
def train(data_folder=".", model_folder="runs", continue_training=False):
    """Run a training pipeline that uses the synthetic-lesion transforms.

    Args:
        data_folder: folder globbed for ``*_ct.nii.gz`` images and
            ``*_seg.nii.gz`` labels; only the first ten of each (in sorted
            order) are used.
        model_folder: folder where checkpoints are saved and, when
            ``continue_training`` is set, looked up.
        continue_training: when true, the network is initialised from the
            last ``*.pt`` file (in sorted order) found in ``model_folder``.

    Raises:
        IndexError: if ``continue_training`` is true and ``model_folder``
            contains no ``*.pt`` file.
    """
    #/== files for synthesis
    path_parent = Path(
        '/content/drive/My Drive/Datasets/covid19/COVID-19-20_augs_cea/')
    path_synthesis = Path(
        path_parent /
        'CeA_BASE_grow=1_bg=-1.00_step=-1.0_scale=-1.0_seed=1.0_ch0_1=-1_ch1_16=-1_ali_thr=0.1'
    )
    scans_syns = os.listdir(path_synthesis)
    decreasing_sequence = get_decreasing_sequence(255, splits=20)
    keys2 = ("image", "label", "synthetic_lesion")
    # READ THE SYNTHETIC HEALTHY TEXTURE
    path_synthesis_old = '/content/drive/My Drive/Datasets/covid19/results/cea_synthesis/patient0/'
    # Read the array inside a context manager so the .npz archive is closed.
    with np.load(f'{path_synthesis_old}texture.npy.npz') as texture_archive:
        texture_orig = texture_archive.f.arr_0
    # shift the texture so its minimum becomes 0.07, then pad by reflection
    texture = texture_orig + np.abs(np.min(texture_orig)) + .07
    texture = np.pad(texture, ((100, 100), (100, 100)), mode='reflect')
    print(f'type(texture) = {type(texture)}, {np.shape(texture)}')
    #==/

    # Sort BEFORE truncating: glob returns files in arbitrary order, so
    # slicing first could select ten images and ten labels that belong to
    # different cases and pair them incorrectly.
    images = sorted(glob.glob(os.path.join(data_folder, "*_ct.nii.gz")))[:10]  #OMM
    labels = sorted(glob.glob(os.path.join(data_folder, "*_seg.nii.gz")))[:10]  #OMM
    logging.info(
        f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    n_val = min(len(images) - n_train, int(val_frac * len(images)))
    logging.info(
        f"training: train {n_train} val {n_val}, folder: {data_folder}")

    train_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[:n_train], labels[:n_train])]
    # NOTE(review): when n_val == 0 the slice [-0:] selects the WHOLE list,
    # so validation would run on the training files - confirm intended.
    val_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[-n_val:], labels[-n_val:])]

    # create a training data loader
    batch_size = 1  # XX was 2
    logging.info(f"batch size {batch_size}")
    train_transforms = get_xforms("synthesis", keys, keys2, path_synthesis,
                                  decreasing_sequence, scans_syns, texture)
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        # collate_fn=pad_list_data_collate,
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create the network and the Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)

    # if continue training
    if continue_training:
        ckpts = sorted(glob.glob(os.path.join(model_folder, "*.pt")))
        ckpt = ckpts[-1]
        logging.info(f"continue training using {ckpt}.")
        net.load_state_dict(torch.load(ckpt, map_location=device))

    # max_epochs, lr, momentum = 500, 1e-4, 0.95
    max_epochs, lr, momentum = 20, 1e-4, 0.95  #OMM
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose([
        AsDiscreted(keys=("pred", "label"),
                    argmax=(True, False),
                    to_onehot=True,
                    n_classes=2)
    ])
    val_handlers = [
        ProgressBar(),
        MetricsSaver(save_dir="./metrics_val", metrics="*"),
        CheckpointSaver(save_dir=model_folder,
                        save_dict={"net": net},
                        save_key_metric=True,
                        key_metric_n_saved=6),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        post_transform=val_post_transform,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        # MetricsSaver(save_dir="./metrics_train", metrics="*"),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False):
    """Run sliding-window evaluation of a 3D UNet and return the best metric.

    Args:
        root_dir: folder globbed for ``im*.nii.gz`` / ``seg*.nii.gz`` pairs;
            also used as the output folder of ``SegmentationSaver``.
        model_file: checkpoint path handed to ``CheckpointLoader``.
        device: device the network and the evaluator run on.
        amp: truthy to enable automatic mixed precision in the evaluator.

    Returns:
        ``evaluator.state.best_metric`` after the run.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = []
    for img, seg in zip(image_paths, seg_paths):
        val_files.append({"image": img, "label": seg})

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create the UNet
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    key_metric = {
        "val_mean_dice": MeanDice(
            include_background=True,
            output_transform=lambda x: (x["pred"], x["label"]),
        )
    }
    extra_metrics = {
        "val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
    }

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric=key_metric,
        additional_metrics=extra_metrics,
        val_handlers=val_handlers,
        amp=True if amp else False,
    )
    evaluator.run()

    return evaluator.state.best_metric
from monai.handlers import PostProcessing from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd # test lambda function as `transform` TEST_CASE_1 = [{ "transform": lambda x: dict(pred=x["pred"] + 1.0) }, torch.tensor([[[[1.9975], [1.9997]]]])] # test composed post transforms as `transform` TEST_CASE_2 = [ { "transform": Compose([ CopyItemsd(keys="filename", times=1, names="filename_bak"), AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), ]) }, torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]), ] class TestHandlerPostProcessing(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_compute(self, input_params, expected): data = [ { "image": torch.tensor([[[[2.0], [3.0]]]]), "filename": "test1" },
def train(data_folder=".", model_folder="runs"):
    """Run a training pipeline.

    Images (``*_ct.nii.gz``) and labels (``*_seg.nii.gz``) found in
    ``data_folder`` are split into a training and a validation subset, a
    network is trained with ``DiceCELoss`` and validated after every epoch;
    the best checkpoints (by ``val_mean_dice``) are written to ``model_folder``.

    Args:
        data_folder: folder containing the image/label NIfTI pairs.
        model_folder: folder in which ``CheckpointSaver`` stores checkpoints.
    """
    images = sorted(glob.glob(os.path.join(data_folder, "*_ct.nii.gz")))
    labels = sorted(glob.glob(os.path.join(data_folder, "*_seg.nii.gz")))
    logging.info(
        f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    n_val = min(len(images) - n_train, int(val_frac * len(images)))
    logging.info(
        f"training: train {n_train} val {n_val}, folder: {data_folder}")
    train_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[:n_train], labels[:n_train])]

    if n_val > 0:
        val_images, val_labels = images[-n_val:], labels[-n_val:]
    else:
        # ``images[-0:]`` is the *whole* list, so with very small datasets the
        # original slice silently validated on the training cases. Keep that
        # fallback (there is nothing else to validate on) but make it explicit
        # and visible in the log.
        logging.warning(
            "no images left for validation; validating on the full dataset, "
            "which overlaps with the training set")
        val_images, val_labels = images, labels
    val_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(val_images, val_labels)]

    # create a training data loader
    batch_size = 2
    logging.info(f"batch size {batch_size}")
    train_transforms = get_xforms("train", keys)
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        # image-level batch to the sliding window method, not the window-level batch
        batch_size=1,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create the network and the Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)
    # NOTE: ``momentum`` is only reported in the log line below; it is not
    # passed to the Adam optimizer.
    max_epochs, lr, momentum = 500, 1e-4, 0.95
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose([
        AsDiscreted(keys=("pred", "label"),
                    argmax=(True, False),
                    to_onehot=True,
                    n_classes=2)
    ])
    val_handlers = [
        ProgressBar(),
        CheckpointSaver(save_dir=model_folder,
                        save_dict={"net": net},
                        save_key_metric=True,
                        key_metric_n_saved=3),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        post_transform=val_post_transform,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
def run_training_test(root_dir, device="cuda:0", amp=False):
    """Train a small 3D UNet for 5 epochs and return the best validation metric.

    Args:
        root_dir: folder searched for ``img*.nii.gz`` / ``seg*.nii.gz`` pairs;
            it is also used for TensorBoard logs and saved checkpoints.
        device: device on which the network, trainer and evaluator run.
        amp: truthy to request automatic mixed precision in both engines.

    Returns:
        ``evaluator.state.best_metric`` after training has finished.
    """
    keys = ("image", "label")
    use_amp = bool(amp)

    def pred_and_label(output):
        # Metrics compare the post-processed prediction with the label.
        return output["pred"], output["label"]

    def make_post_transforms():
        # Trainer and evaluator each get their own instance of the same recipe.
        return Compose(
            [
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold_values=True),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )

    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    # First 20 pairs are used for training, last 20 pairs for validation.
    train_files = [dict(zip(keys, pair)) for pair in zip(image_paths[:20], seg_paths[:20])]
    val_files = [dict(zip(keys, pair)) for pair in zip(image_paths[-20:], seg_paths[-20:])]

    # Pre-processing for image and segmentation.
    train_transforms = Compose(
        [
            LoadNiftid(keys=keys),
            AsChannelFirstd(keys=keys, channel_dim=-1),
            ScaleIntensityd(keys=keys),
            RandCropByPosNegLabeld(
                keys=keys,
                label_key="label",
                spatial_size=[96, 96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=keys, prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=keys),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=keys),
            AsChannelFirstd(keys=keys, channel_dim=-1),
            ScaleIntensityd(keys=keys),
            ToTensord(keys=keys),
        ]
    )

    # Training loader: batch_size=2 volumes, each cropped into 4 samples by
    # RandCropByPosNegLabeld, i.e. 2 x 4 patches per network step.
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # Validation loader: one full volume per batch.
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # UNet, Dice loss, Adam optimizer and a step LR schedule.
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    val_post_transforms = make_post_transforms()
    val_handlers = [
        StatsHandler(output_transform=lambda _: None),
        TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda _: None),
        TensorBoardImageHandler(
            log_dir=root_dir,
            batch_transform=lambda batch: (batch["image"], batch["label"]),
            output_transform=lambda output: output["pred"],
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=pred_and_label)
        },
        additional_metrics={"val_acc": Accuracy(output_transform=pred_and_label)},
        val_handlers=val_handlers,
        amp=use_amp,
    )

    train_post_transforms = make_post_transforms()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        # Validation is triggered by the trainer every second epoch.
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda output: output["loss"]),
        TensorBoardStatsHandler(
            log_dir=root_dir,
            tag_name="train_loss",
            output_transform=lambda output: output["loss"],
        ),
        CheckpointSaver(
            save_dir=root_dir,
            save_dict={"net": net, "opt": opt},
            save_interval=2,
            epoch_level=True,
        ),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=pred_and_label)},
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()

    return evaluator.state.best_metric
def train_post_transforms(self, context: Context):
    """Return the post-processing transforms applied to ``pred`` during training.

    ``context`` is accepted but not used here. The list contains a sigmoid
    activation followed by value thresholding with ``logit_thresh=0.5``.
    """
    sigmoid = Activationsd(keys="pred", sigmoid=True)
    binarize = AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5)
    return [sigmoid, binarize]
def main(tempdir):
    """Train a 3D UNet on synthetic data with the MONAI workflow engines.

    Forty synthetic image/segmentation volumes are written to ``tempdir``;
    the first 20 pairs are used for training and the last 20 for validation.
    Logs and checkpoints go to ``./runs/``.

    Args:
        tempdir: folder in which the synthetic NIfTI files are created.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    keys = ("image", "label")
    run_dir = "./runs/"

    def pred_and_label(output):
        # Metrics compare the post-processed prediction with the label.
        return output["pred"], output["label"]

    def make_post_transforms():
        # Trainer and evaluator each get their own instance of the same recipe.
        return Compose(
            [
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold_values=True),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )

    # ---- dataset: 40 random image / mask pairs written to ``tempdir`` ----
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    identity_affine_shape = 4
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(im, np.eye(identity_affine_shape)), os.path.join(tempdir, f"img{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg, np.eye(identity_affine_shape)), os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [dict(zip(keys, pair)) for pair in zip(image_paths[:20], seg_paths[:20])]
    val_files = [dict(zip(keys, pair)) for pair in zip(image_paths[-20:], seg_paths[-20:])]

    # Pre-processing; only the image intensities are rescaled.
    train_transforms = Compose(
        [
            LoadImaged(keys=keys),
            AsChannelFirstd(keys=keys, channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=keys,
                label_key="label",
                spatial_size=[96, 96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=keys, prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=keys),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=keys),
            AsChannelFirstd(keys=keys, channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=keys),
        ]
    )

    # Training loader: batch_size=2 volumes, each cropped into 4 samples by
    # RandCropByPosNegLabeld, i.e. 2 x 4 patches per network step.
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # Validation loader: one full volume per batch.
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # ---- network ----
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # ---- loss, optimizer, learning-rate schedule ----
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    # ---- evaluator ----
    val_post_transforms = make_post_transforms()
    val_handlers = [
        StatsHandler(output_transform=lambda _: None),
        TensorBoardStatsHandler(log_dir=run_dir, output_transform=lambda _: None),
        TensorBoardImageHandler(
            log_dir=run_dir,
            batch_transform=lambda batch: (batch["image"], batch["label"]),
            output_transform=lambda output: output["pred"],
        ),
        CheckpointSaver(save_dir=run_dir, save_dict={"net": net}, save_key_metric=True),
    ]

    # AMP is switched on purely from the PyTorch version (>= 1.6).
    use_amp = monai.utils.get_torch_version_tuple() >= (1, 6)

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=pred_and_label)
        },
        additional_metrics={"val_acc": Accuracy(output_transform=pred_and_label)},
        val_handlers=val_handlers,
        amp=use_amp,
    )

    # ---- trainer ----
    train_post_transforms = make_post_transforms()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        # Validation is triggered by the trainer every second epoch.
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda output: output["loss"]),
        TensorBoardStatsHandler(
            log_dir=run_dir,
            tag_name="train_loss",
            output_transform=lambda output: output["loss"],
        ),
        CheckpointSaver(
            save_dir=run_dir,
            save_dict={"net": net, "opt": opt},
            save_interval=2,
            epoch_level=True,
        ),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=pred_and_label)},
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()
def train(index):
    """Build, train and return the DynUNet for ensemble member ``index``.

    The function reads ``opt``, ``device``, ``train_loaders`` and
    ``val_loaders`` from the enclosing scope.

    Args:
        index: position of the train/validation loader pair to use.

    Returns:
        The trained network.
    """
    # ---------- Build the nn-Unet network ------------
    sizes = opt.patch_size
    spacings = opt.spacing if opt.resolution is None else opt.resolution

    # Derive per-level kernel sizes and strides from the voxel spacing:
    # an axis is down-sampled (stride 2) while it is not much coarser than the
    # finest axis (ratio <= 2) and still has at least 8 voxels.
    strides, kernels = [], []
    while True:
        finest = min(spacings)
        spacing_ratio = [sp / finest for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    # The input level never down-samples; the bottleneck uses 3x3x3 kernels.
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    net = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        res_block=True,
        # act=act_type,
        # norm=Norm.BATCH,
    ).to(device)

    from torchsummaryX import summary

    # Probe the network with one random batch. The tensor is moved to the same
    # ``device`` as the network (a hard-coded ``.cuda()`` breaks on CPU or on a
    # non-default GPU); it is still sampled on the CPU so the random stream
    # consumed here is unchanged.
    data = torch.randn(int(opt.batch_size), int(opt.in_channels),
                       int(opt.patch_size[0]), int(opt.patch_size[1]),
                       int(opt.patch_size[2])).to(device)
    # No gradients are needed for the shape check; this avoids keeping the
    # autograd graph of the probe alive for the whole training run.
    with torch.no_grad():
        out = net(data)
    summary(net, data)
    print("out size: {}".format(out.size()))

    # if opt.preload is not None:
    #     net.load_state_dict(torch.load(opt.preload))

    # ---------- ------------------------ ------------

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    # Polynomial decay of the learning rate over ``opt.epochs`` epochs.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)

    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
    ])

    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={"net": net},
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loaders[index],
        network=net,
        inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                     sw_batch_size=opt.batch_size,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers)

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])

    train_handlers = [
        # Validation is triggered by the trainer every fifth epoch.
        ValidationHandler(validator=evaluator, interval=5, epoch_level=True),
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": optim
                        },
                        save_final=True,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=opt.epochs,
        train_data_loader=train_loaders[index],
        network=net,
        optimizer=optim,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        amp=False,
        train_handlers=train_handlers,
    )
    trainer.run()
    return net
def main(tempdir):
    """Run sliding-window inference on synthetic data with a saved checkpoint.

    Five synthetic image/segmentation volumes are written to ``tempdir`` and
    evaluated with the first ``./runs/net_key_metric*`` checkpoint found;
    predicted segmentations are saved to ``./runs/``.

    Args:
        tempdir: folder in which the synthetic NIfTI files are created.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    keys = ("image", "label")

    def pred_and_label(output):
        # Metrics compare the post-processed prediction with the label.
        return output["pred"], output["label"]

    # Create 5 random image / mask pairs in ``tempdir``.
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    affine_size = 4
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(im, np.eye(affine_size)), os.path.join(tempdir, f"im{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg, np.eye(affine_size)), os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [dict(zip(keys, pair)) for pair in zip(image_paths, seg_paths)]

    # Checkpoint to restore: first match of the key-metric checkpoint pattern.
    checkpoints = glob("./runs/net_key_metric*")
    model_file = checkpoints[0]

    # Pre-processing; only the image intensities are rescaled.
    val_transforms = Compose(
        [
            LoadNiftid(keys=keys),
            AsChannelFirstd(keys=keys, channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=keys),
        ]
    )

    # Validation data loader (one volume per batch).
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # Network whose weights are filled in by the CheckpointLoader below.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    stats_handler = StatsHandler(output_transform=lambda _: None)
    checkpoint_loader = CheckpointLoader(load_path=model_file, load_dict={"net": net})
    segmentation_saver = SegmentationSaver(
        output_dir="./runs/",
        batch_transform=lambda batch: batch["image_meta_dict"],
        output_transform=lambda output: output["pred"],
    )

    # AMP is switched on purely from the PyTorch version (>= 1.6).
    use_amp = monai.config.get_torch_version_tuple() >= (1, 6)

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=pred_and_label)
        },
        additional_metrics={"val_acc": Accuracy(output_transform=pred_and_label)},
        val_handlers=[stats_handler, checkpoint_loader, segmentation_saver],
        amp=use_amp,
    )
    evaluator.run()
def main():
    """Train an ensemble of DynUNet models and evaluate two ensembling schemes.

    Steps:
      1. Parse the options and pair up ``image*.nii`` / ``label*.nii`` files.
      2. Build one train/validation split per ensemble member plus a shared
         test set made of the last ``opt.split_test`` cases.
      3. Train ``opt.models_ensemble`` networks.
      4. Evaluate the ensemble on the test set with ``MeanEnsembled`` and with
         ``VoteEnsembled``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    set_determinism(seed=0)

    device = torch.device(opt.gpu_id)

    # ------- Data loader creation ----------

    images = sorted(glob(os.path.join(opt.images_folder, 'image*.nii')))
    segs = sorted(glob(os.path.join(opt.labels_folder, 'label*.nii')))

    def _pair_up(image_paths, seg_paths):
        # One {"image", "label"} record per case.
        return [{
            "image": img,
            "label": seg
        } for img, seg in zip(image_paths, seg_paths)]

    # The last ``opt.split_test`` cases are reserved for testing. Training data
    # must stop at that boundary; the previous bound (``len - opt.split_val``)
    # leaked test cases into training whenever split_test > split_val.
    test_start = len(images) - opt.split_test

    train_files = []
    val_files = []
    for i in range(opt.models_ensemble):
        # Fold ``i`` validates on the i-th chunk of ``opt.split_val`` cases and
        # trains on everything else that precedes the test set.
        fold_start = opt.split_val * i
        fold_end = opt.split_val * (i + 1)

        train_files.append(
            _pair_up(images[:fold_start] + images[fold_end:test_start],
                     segs[:fold_start] + segs[fold_end:test_start]))
        val_files.append(
            _pair_up(images[fold_start:fold_end], segs[fold_start:fold_end]))

    test_files = _pair_up(images[test_start:len(images)],
                          segs[test_start:len(images)])

    # ----------- Transforms list --------------

    both = ('image', 'label')
    interpolation = ('bilinear', 'nearest')

    def _preprocessing():
        # Deterministic steps shared by training and validation; resampling is
        # only added when a target resolution was requested.
        steps = [
            LoadImaged(keys=both),
            AddChanneld(keys=both),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if opt.resolution is not None:
            steps.append(
                Spacingd(keys=both, pixdim=opt.resolution,
                         mode=interpolation))
        return steps

    # NOTE(review): the first rotation range uses ``np.pi * 2`` on the last
    # axis while the other two use ``np.pi / 2`` - confirm this is intended.
    rotate_ranges = [
        (np.pi / 36, np.pi / 36, np.pi * 2),
        (np.pi / 36, np.pi / 2, np.pi / 36),
        (np.pi / 2, np.pi / 36, np.pi / 36),
    ]
    augmentations = [
        RandFlipd(keys=both, prob=0.1, spatial_axis=1),
        RandFlipd(keys=both, prob=0.1, spatial_axis=0),
        RandFlipd(keys=both, prob=0.1, spatial_axis=2),
    ]
    augmentations += [
        RandAffined(keys=both,
                    mode=interpolation,
                    prob=0.1,
                    rotate_range=rotate_range,
                    padding_mode="zeros") for rotate_range in rotate_ranges
    ]
    augmentations += [
        Rand3DElasticd(keys=both,
                       mode=interpolation,
                       prob=0.1,
                       sigma_range=(5, 8),
                       magnitude_range=(100, 200),
                       scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # The noise mean/std and the intensity offset are sampled once, here,
        # when the transform objects are created.
        RandGaussianNoised(keys=['image'],
                           prob=0.1,
                           mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'],
                            offsets=np.random.uniform(0, 0.3),
                            prob=0.1),
        RandSpatialCropd(keys=both,
                         roi_size=opt.patch_size,
                         random_size=False),
    ]

    train_transforms = Compose(_preprocessing() + augmentations +
                               [ToTensord(keys=both)])
    val_transforms = Compose(_preprocessing() + [ToTensord(keys=both)])

    # ---------- Creation of DataLoaders -------------

    train_dss = [
        CacheDataset(data=files, transform=train_transforms)
        for files in train_files
    ]
    train_loaders = [
        DataLoader(ds,
                   batch_size=opt.batch_size,
                   shuffle=True,
                   num_workers=opt.workers,
                   pin_memory=torch.cuda.is_available()) for ds in train_dss
    ]

    val_dss = [
        CacheDataset(data=files, transform=val_transforms)
        for files in val_files
    ]
    val_loaders = [
        DataLoader(ds,
                   batch_size=1,
                   num_workers=opt.workers,
                   pin_memory=torch.cuda.is_available()) for ds in val_dss
    ]

    test_ds = CacheDataset(data=test_files, transform=val_transforms)
    test_loader = DataLoader(test_ds,
                             batch_size=1,
                             num_workers=opt.workers,
                             pin_memory=torch.cuda.is_available())

    def train(index):
        """Build, train and return the DynUNet for ensemble member ``index``."""

        # ---------- Build the nn-Unet network ------------
        sizes = opt.patch_size
        spacings = opt.spacing if opt.resolution is None else opt.resolution

        # Derive per-level kernel sizes and strides from the voxel spacing:
        # an axis is down-sampled (stride 2) while it is not much coarser than
        # the finest axis (ratio <= 2) and still has at least 8 voxels.
        strides, kernels = [], []
        while True:
            finest = min(spacings)
            spacing_ratio = [sp / finest for sp in spacings]
            stride = [
                2 if ratio <= 2 and size >= 8 else 1
                for (ratio, size) in zip(spacing_ratio, sizes)
            ]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            sizes = [i / j for i, j in zip(sizes, stride)]
            spacings = [i * j for i, j in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
        # The input level never down-samples; the bottleneck uses 3x3x3 kernels.
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])

        net = monai.networks.nets.DynUNet(
            spatial_dims=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            res_block=True,
            # act=act_type,
            # norm=Norm.BATCH,
        ).to(device)

        from torchsummaryX import summary

        # Probe the network with one random batch. The tensor is moved to the
        # same ``device`` as the network (a hard-coded ``.cuda()`` breaks on
        # CPU or on a non-default GPU); it is still sampled on the CPU so the
        # random stream consumed here is unchanged.
        data = torch.randn(int(opt.batch_size), int(opt.in_channels),
                           int(opt.patch_size[0]), int(opt.patch_size[1]),
                           int(opt.patch_size[2])).to(device)
        # No gradients are needed for the shape check; this avoids keeping the
        # autograd graph of the probe alive for the whole training run.
        with torch.no_grad():
            out = net(data)
        summary(net, data)
        print("out size: {}".format(out.size()))

        # if opt.preload is not None:
        #     net.load_state_dict(torch.load(opt.preload))

        # ---------- ------------------------ ------------

        optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
        # Polynomial decay of the learning rate over ``opt.epochs`` epochs.
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)

        loss_function = monai.losses.DiceCELoss(sigmoid=True)

        val_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
        ])

        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={"net": net},
                            save_key_metric=True),
        ]

        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loaders[index],
            network=net,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                         sw_batch_size=opt.batch_size,
                                         overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice":
                MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
            val_handlers=val_handlers)

        train_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])

        train_handlers = [
            # Validation is triggered by the trainer every fifth epoch.
            ValidationHandler(validator=evaluator,
                              interval=5,
                              epoch_level=True),
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={
                                "net": net,
                                "opt": optim
                            },
                            save_final=True,
                            epoch_level=True),
        ]

        trainer = SupervisedTrainer(
            device=device,
            max_epochs=opt.epochs,
            train_data_loader=train_loaders[index],
            network=net,
            optimizer=optim,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=train_post_transforms,
            amp=False,
            train_handlers=train_handlers,
        )
        trainer.run()
        return net

    models = [train(i) for i in range(opt.models_ensemble)]

    # -------- Test the models ---------

    def ensemble_evaluate(post_transforms, models):
        """Run all ``models`` on the test set and score the combined prediction."""
        evaluator = EnsembleEvaluator(
            device=device,
            val_data_loader=test_loader,
            pred_keys=opt.pred_keys,
            networks=models,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                         sw_batch_size=opt.batch_size,
                                         overlap=0.5),
            post_transform=post_transforms,
            key_val_metric={
                "test_mean_dice":
                MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
        )
        evaluator.run()

    # Weighted average of the raw outputs, then sigmoid + thresholding.
    mean_post_transforms = Compose([
        MeanEnsembled(
            keys=opt.pred_keys,
            output_key="pred",
            # in this particular example, we use validation metrics as weights
            weights=opt.weights_models,
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
    ])

    print('Results from MeanEnsembled:')
    ensemble_evaluate(mean_post_transforms, models)

    vote_post_transforms = Compose([
        Activationsd(keys=opt.pred_keys, sigmoid=True),
        # transform data into discrete before voting
        AsDiscreted(keys=opt.pred_keys, threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        VoteEnsembled(keys=opt.pred_keys, output_key="pred"),
    ])

    print('Results from VoteEnsembled:')
    ensemble_evaluate(vote_post_transforms, models)