def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device("cuda:0")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # the evaluator has no loss to print, so only metrics are printed (print functions can be customized)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: predict_segmentation(output[0]),
    )
    file_saver.attach(evaluator)

    # the model was trained by the "unet_training_array" example
    ckpt_loader = CheckpointLoader(load_path="./runs/net_checkpoint_100.pth", load_dict={"net": net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    print(state)
    shutil.rmtree(tempdir)
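# NOTE (assumption): the array-based evaluation snippet above does not show its imports.
# The sketch below lists the imports its names imply; exact module paths differ across
# early MONAI releases, so treat this as a hint rather than the original file's header.
import logging
import os
import shutil
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.engine import Engine
from torch.utils.data import DataLoader

from monai import config
from monai.data import NiftiDataset, create_test_image_3d
from monai.handlers import CheckpointLoader, MeanDice, SegmentationSaver, StatsHandler
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
from monai.networks.utils import predict_segmentation
from monai.transforms import AddChannel, Compose, ScaleIntensity, ToTensor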
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    # define transforms for image
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

    # define nifti dataset
    val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)

    # create DenseNet121
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}

    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch[0], batch[1]), device, non_blocking)

    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    prediction_saver = ClassificationSaver(
        output_dir="tempdir",
        batch_transform=lambda batch: batch[2],
        output_transform=lambda output: output[0].argmax(1),
    )
    prediction_saver.attach(evaluator)

    # the model was trained by the "densenet_training_array" example
    CheckpointLoader(load_path="./runs_array/net_checkpoint_20.pth", load_dict={"net": net}).attach(evaluator)

    # create a validation data loader
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    state = evaluator.run(val_loader)
    print(state)
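# Why the custom prepare_batch above is needed (illustration, not part of the example):
# with image_only=False the dataset yields (image, label, meta_dict) per item, while
# Ignite's create_supervised_evaluator expects plain (x, y) pairs; prepare_batch drops
# the metadata before handing the pair to ignite.engine._prepare_batch, and the metadata
# is still available to ClassificationSaver through batch_transform=lambda batch: batch[2].
# img, label, meta = val_ds[0]            # what image_only=False yields (illustration)
# x, y = prepare_batch((img, label, meta))  # -> (img, label) moved to the device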
# we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
)
val_stats_handler.attach(evaluator)

# for the array data format, assume the 3rd item of batch data is the meta_data
file_saver = SegmentationSaver(
    output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
    batch_transform=lambda x: x[2],
    output_transform=lambda output: predict_segmentation(output[0]))
file_saver.attach(evaluator)

# the model was trained by the "unet_training_array" example
ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
ckpt_loader.attach(evaluator)

# sliding window inference for one image at every iteration
loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
state = evaluator.run(loader)
shutil.rmtree(tempdir)
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 5 random image/mask pairs
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path="./runs/net_key_metric=0.9101.pth", load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
    )
    evaluator.run()
    shutil.rmtree(tempdir)
def main(): """ Basic UNet as implemented in MONAI for Fetal Brain Segmentation, but using ignite to manage training and validation loop and checkpointing :return: """ """ Read input and configuration parameters """ parser = argparse.ArgumentParser( description='Run basic UNet with MONAI - Ignite version.') parser.add_argument('--config', dest='config', metavar='config', type=str, help='config file') args = parser.parse_args() with open(args.config) as f: config_info = yaml.load(f, Loader=yaml.FullLoader) # print to log the parameter setups print(yaml.dump(config_info)) # GPU params cuda_device = config_info['device']['cuda_device'] num_workers = config_info['device']['num_workers'] # training and validation params loss_type = config_info['training']['loss_type'] batch_size_train = config_info['training']['batch_size_train'] batch_size_valid = config_info['training']['batch_size_valid'] lr = float(config_info['training']['lr']) lr_decay = config_info['training']['lr_decay'] if lr_decay is not None: lr_decay = float(lr_decay) nr_train_epochs = config_info['training']['nr_train_epochs'] validation_every_n_epochs = config_info['training'][ 'validation_every_n_epochs'] sliding_window_validation = config_info['training'][ 'sliding_window_validation'] if 'model_to_load' in config_info['training'].keys(): model_to_load = config_info['training']['model_to_load'] if not os.path.exists(model_to_load): raise BlockingIOError( "cannot find model: {}".format(model_to_load)) else: model_to_load = None if 'manual_seed' in config_info['training'].keys(): seed = config_info['training']['manual_seed'] else: seed = None # data params data_root = config_info['data']['data_root'] training_list = config_info['data']['training_list'] validation_list = config_info['data']['validation_list'] # model saving out_model_dir = os.path.join( config_info['output']['out_model_dir'], datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + config_info['output']['output_subfix']) print("Saving to directory ", out_model_dir) if 'cache_dir' in config_info['output'].keys(): out_cache_dir = config_info['output']['cache_dir'] else: out_cache_dir = os.path.join(out_model_dir, 'persistent_cache') max_nr_models_saved = config_info['output']['max_nr_models_saved'] val_image_to_tensorboad = config_info['output']['val_image_to_tensorboad'] monai.config.print_config() logging.basicConfig(stream=sys.stdout, level=logging.INFO) torch.cuda.set_device(cuda_device) if seed is not None: # set manual seed if required (both numpy and torch) set_determinism(seed=seed) # # set torch only seed # torch.manual_seed(seed) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False """ Data Preparation """ # create cache directory to store results for Persistent Dataset persistent_cache: Path = Path(out_cache_dir) persistent_cache.mkdir(parents=True, exist_ok=True) # create training and validation data lists train_files = create_data_list(data_folder_list=data_root, subject_list=training_list, img_postfix='_Image', label_postfix='_Label') print(len(train_files)) print(train_files[0]) print(train_files[-1]) val_files = create_data_list(data_folder_list=data_root, subject_list=validation_list, img_postfix='_Image', label_postfix='_Label') print(len(val_files)) print(val_files[0]) print(val_files[-1]) # data preprocessing for training: # - convert data to right format [batch, channel, dim, dim, dim] # - apply whitening # - resize to (96, 96) in-plane (preserve z-direction) # - define 2D patches to be extracted # - add data 
augmentation (random rotation and random flip) # - squeeze to 2D train_transforms = Compose([ LoadNiftid(keys=['img', 'seg']), AddChanneld(keys=['img', 'seg']), NormalizeIntensityd(keys=['img']), Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]), RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False), RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1], interp_order=[1, 0], reshape=False), RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]), SqueezeDimd(keys=['img', 'seg'], dim=-1), ToTensord(keys=['img', 'seg']) ]) # create a training data loader # train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) # train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, # num_workers=num_workers) train_ds = monai.data.PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache) train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True, num_workers=num_workers, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) # check_train_data = monai.utils.misc.first(train_loader) # print("Training data tensor shapes") # print(check_train_data['img'].shape, check_train_data['seg'].shape) # data preprocessing for validation: # - convert data to right format [batch, channel, dim, dim, dim] # - apply whitening # - resize to (96, 96) in-plane (preserve z-direction) if sliding_window_validation: val_transforms = Compose([ LoadNiftid(keys=['img', 'seg']), AddChanneld(keys=['img', 'seg']), NormalizeIntensityd(keys=['img']), Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]), ToTensord(keys=['img', 'seg']) ]) do_shuffle = False collate_fn_to_use = None else: # - add extraction of 2D slices from validation set to emulate how loss is computed at training val_transforms = Compose([ LoadNiftid(keys=['img', 'seg']), AddChanneld(keys=['img', 'seg']), NormalizeIntensityd(keys=['img']), Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]), RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False), SqueezeDimd(keys=['img', 'seg'], dim=-1), ToTensord(keys=['img', 'seg']) ]) do_shuffle = True collate_fn_to_use = list_data_collate # create a validation data loader # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, # num_workers=num_workers) val_ds = monai.data.PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache) val_loader = DataLoader(val_ds, batch_size=batch_size_valid, shuffle=do_shuffle, collate_fn=collate_fn_to_use, num_workers=num_workers) # check_valid_data = monai.utils.misc.first(val_loader) # print("Validation data tensor shapes") # print(check_valid_data['img'].shape, check_valid_data['seg'].shape) """ Network preparation """ # Create UNet, DiceLoss and Adam optimizer. 
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()
    if lr_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=lr_decay, last_epoch=-1)

    """
    Set ignite trainer
    """
    # function to manage batch at training
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch['img'], batch['seg']), device, non_blocking)

    trainer = create_supervised_trainer(model=net, optimizer=opt, loss_fn=loss_function,
                                         device=device, non_blocking=False, prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    if model_to_load is not None:
        checkpoint_handler = CheckpointLoader(load_path=model_to_load, load_dict={'net': net, 'opt': opt})
        checkpoint_handler.attach(trainer)
        state = trainer.state_dict()
    else:
        checkpoint_handler = ModelCheckpoint(out_model_dir, 'net', n_saved=max_nr_models_saved, require_empty=False)
        # trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=save_params)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net, 'opt': opt})

    # StatsHandler prints loss at every iteration and prints metrics at every epoch
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_tensorboard_stats_handler = TensorBoardStatsHandler(summary_writer=writer_train)
    train_tensorboard_stats_handler.attach(trainer)

    if lr_decay is not None:
        print("Using Exponential LR decay")
        lr_schedule_handler = LrScheduleHandler(lr_scheduler, print_lr=True, name="lr_scheduler", writer=writer_train)
        lr_schedule_handler.attach(trainer)

    """
    Set ignite evaluator to perform validation at training
    """
    # set parameters for validation
    metric_name = 'Mean_Dice'
    # add evaluation metric to the evaluator engine
    val_metrics = {
        "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False),
        "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False)
    }

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch['img'].to(device), batch['seg'].to(device)
            roi_size = (96, 96, 1)
            seg_probs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
            return seg_probs, val_labels

    if sliding_window_validation:
        # use sliding window inference at validation
        print("3D evaluator is used")
        net.to(device)
        evaluator = Engine(_sliding_window_processor)
        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)
    else:
        # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration;
        # user can add output_transform to return other values
        print("2D evaluator is used")
        evaluator = create_supervised_evaluator(model=net, metrics=val_metrics, device=device,
                                                non_blocking=True, prepare_batch=prepare_batch)

    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    # early_stopper = EarlyStopping(patience=4,
    #                               score_function=stopping_fn_from_metric(metric_name),
    #                               trainer=trainer)
    # evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_valid,
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.iteration)  # fetch global iteration number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch;
    # here we draw the 3D output as GIF format along the depth axis, at every validation iteration
    if val_image_to_tensorboad:
        val_tensorboard_image_handler = TensorBoardImageHandler(
            summary_writer=writer_valid,
            batch_transform=lambda batch: (batch['img'], batch['seg']),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch)
        evaluator.add_event_handler(
            event_name=Events.ITERATION_COMPLETED(every=1),
            handler=val_tensorboard_image_handler)

    """
    Run training
    """
    state = trainer.run(train_loader, nr_train_epochs)
    print("Done!")
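# NOTE (assumption): the training script above only consumes `config_info`; a YAML config with the
# following shape would satisfy every key it reads. The values below are illustrative placeholders,
# not the original project's settings; `model_to_load`, `manual_seed` and `cache_dir` are optional.
example_config_info = {
    "device": {"cuda_device": 0, "num_workers": 4},
    "training": {
        "loss_type": "Dice",
        "batch_size_train": 16,
        "batch_size_valid": 1,
        "lr": 1e-3,
        "lr_decay": None,
        "nr_train_epochs": 100,
        "validation_every_n_epochs": 1,
        "sliding_window_validation": True,
        # "model_to_load": "path/to/checkpoint.pth",
        # "manual_seed": 0,
    },
    "data": {
        "data_root": "path/to/data",
        "training_list": "path/to/training_list.txt",
        "validation_list": "path/to/validation_list.txt",
    },
    "output": {
        "out_model_dir": "path/to/models",
        "output_subfix": "fetal_unet",
        "max_nr_models_saved": 5,
        "val_image_to_tensorboad": True,
        # "cache_dir": "path/to/persistent_cache",
    },
}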
def run_inference_test(root_dir, model_file, device=torch.device("cuda:0")):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", output_postfix="act", sigmoid=True),
        AsDiscreted(keys="pred_act", output_postfix="dis", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred_act_dis", applied_values=[1], output_postfix=None),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            batch_transform=lambda x: {
                "filename_or_obj": x["image.filename_or_obj"],
                "affine": x["image.affine"],
                "original_affine": x["image.original_affine"],
                "spatial_shape": x["image.spatial_shape"],
            },
            output_transform=lambda x: x["pred_act_dis"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred_act_dis"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred_act_dis"], x["label"]))},
        val_handlers=val_handlers,
    )
    evaluator.run()
    return evaluator.state.best_metric
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    device = torch.device("cuda:0")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # the evaluator has no loss to print, so only metrics are printed (print functions can be customized)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # convert the necessary metadata from batch data
    SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda batch: {"filename_or_obj": batch["img.filename_or_obj"], "affine": batch["img.affine"]},
        output_transform=lambda output: predict_segmentation(output[0]),
    ).attach(evaluator)

    # the model was trained by the "unet_training_dict" example
    CheckpointLoader(load_path="./runs/net_checkpoint_50.pth", load_dict={"net": net}).attach(evaluator)

    # sliding window inference for one image at every iteration
    val_loader = DataLoader(
        val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
    )
    state = evaluator.run(val_loader)
    print(state)
    shutil.rmtree(tempdir)
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False)
    # sliding window inference takes one image per iteration as input
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler)

    # create UNet
    device = torch.device(f"cuda:{args.local_rank}")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[args.local_rank])

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        CheckpointLoader(
            load_path="./runs/checkpoint_epoch=4.pth",
            load_dict={"net": net},
            # config mapping to expected GPU device
            map_location={"cuda:0": f"cuda:{args.local_rank}"},
        ),
    ]
    if dist.get_rank() == 0:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        val_handlers.extend(
            [
                StatsHandler(output_transform=lambda x: None),
                SegmentationSaver(
                    output_dir="./runs/",
                    batch_transform=lambda batch: batch["image_meta_dict"],
                    output_transform=lambda output: output["pred"],
                ),
            ]
        )

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=device,
            )
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)},
        val_handlers=val_handlers,
        # if there is no FP16 support in the GPU or the PyTorch version is < 1.6, AMP evaluation will not be enabled
        amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )
    evaluator.run()
    dist.destroy_process_group()
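# NOTE (assumption): the distributed snippet above reads `args.local_rank` and `args.dir` but does not
# show how they are parsed. A minimal entry point consistent with `init_method="env://"` could look
# like this; the argument names, script filename and launch command are assumptions, not the original script.
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0, help="set automatically by torch.distributed.launch")
    parser.add_argument("-d", "--dir", default="./testdata", help="directory to create the synthetic data in")
    args = parser.parse_args()
    evaluate(args)

# example launch (one process per GPU):
#   python -m torch.distributed.launch --nproc_per_node=2 unet_evaluation_ddp.py -d ./testdata
if __name__ == "__main__":
    main()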
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4): images = sorted(glob(os.path.join(root_dir, "im*.nii.gz"))) segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz"))) val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)] # define transforms for image and segmentation val_transforms = Compose( [ LoadImaged(keys=["image", "label"]), AsChannelFirstd(keys=["image", "label"], channel_dim=-1), ScaleIntensityd(keys=["image", "label"]), ToTensord(keys=["image", "label"]), ] ) # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers) # create UNet, DiceLoss and Adam optimizer net = monai.networks.nets.UNet( spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ).to(device) val_postprocessing = Compose( [ ToTensord(keys=["pred", "label"]), Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold=0.5), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch` SaveImaged( keys="pred", meta_keys=PostFix.meta("image"), output_dir=root_dir, output_postfix="seg_transform" ), ] ) val_handlers = [ StatsHandler(iteration_log=False), CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}), SegmentationSaver( output_dir=root_dir, output_postfix="seg_handler", batch_transform=from_engine(PostFix.meta("image")), output_transform=from_engine("pred"), ), ] evaluator = SupervisedEvaluator( device=device, val_data_loader=val_loader, network=net, inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5), postprocessing=val_postprocessing, key_val_metric={ "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"])) }, additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}, val_handlers=val_handlers, amp=bool(amp), ) evaluator.run() return evaluator.state.best_metric
metric_name = 'Accuracy'
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: Accuracy()}

# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

# add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
)
val_stats_handler.attach(evaluator)

# for the dict data format, extract the necessary metadata from batch data
prediction_saver = ClassificationSaver(
    output_dir='tempdir',
    name='evaluator',
    batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj']},
    output_transform=lambda output: output[0].argmax(1))
prediction_saver.attach(evaluator)

# the model was trained by the "densenet_training_dict" example
CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)

# create a validation data loader
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

state = evaluator.run(val_loader)
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
    segtrans = Compose([AddChannel(), EnsureType()])
    ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
            val_data = decollate_batch(batch[2])
            for seg_prob, data in zip(seg_probs, val_data):
                save_image(seg_prob, data)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # the evaluator has no loss to print, so only metrics are printed (print functions can be customized)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # the model was trained by the "unet_training_array" example
    ckpt_loader = CheckpointLoader(load_path="./runs_array/net_checkpoint_100.pt", load_dict={"net": net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    print(state)
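# NOTE (assumption): this last snippet uses the newer MONAI API (ImageDataset, EnsureType,
# decollate_batch, SaveImage). The imports it implies are sketched below; module paths are
# a best guess for recent MONAI releases and may differ slightly between versions.
import logging
import os
import sys
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.engine import Engine
from torch.utils.data import DataLoader

from monai import config
from monai.data import ImageDataset, create_test_image_3d, decollate_batch
from monai.handlers import CheckpointLoader, MeanDice, StatsHandler
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, EnsureType, SaveImage, ScaleIntensity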