def get_network(properties, task_id, pretrain_path, checkpoint=None):
    """Build a 3D DynUNet for ``task_id`` and optionally load pretrained weights.

    Args:
        properties: dataset properties; ``len(properties["labels"])`` sets the
            number of output channels and ``len(properties["modality"])`` the
            number of input channels.
        task_id: key used to look up kernels/strides via
            ``get_kernels_strides`` and the number of deep supervision heads
            via ``deep_supr_num``.
        pretrain_path: directory that holds pretrained checkpoints.
        checkpoint: optional checkpoint file name inside ``pretrain_path``.
            When ``None`` no weights are loaded.

    Returns:
        The constructed ``DynUNet``; it is not moved to any device here.
    """
    n_class = len(properties["labels"])
    in_channels = len(properties["modality"])
    kernels, strides = get_kernels_strides(task_id)

    net = DynUNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=n_class,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=deep_supr_num[task_id],
    )

    if checkpoint is not None:
        pretrain_path = os.path.join(pretrain_path, checkpoint)
        if os.path.exists(pretrain_path):
            # Map tensors to CPU so a checkpoint saved on a GPU also loads on a
            # CPU-only machine; load_state_dict copies the values into the
            # network's own parameters wherever those live.
            net.load_state_dict(torch.load(pretrain_path, map_location="cpu"))
            print("pretrained checkpoint: {} loaded".format(pretrain_path))
        else:
            print("no pretrained checkpoint")
    return net
def test_shape(self, input_param, input_shape, expected_shape):
    """Check the forward output shape and every extra feature-map shape.

    The first tensor compared is the forward output; the remaining ones come
    from ``get_feature_maps()``, matched in order against ``expected_shape``.
    """
    model = DynUNet(**input_param).to(device)
    with torch.no_grad():
        main_output = model(torch.randn(input_shape).to(device))
        outputs = [main_output] + model.get_feature_maps()
    self.assertEqual(len(outputs), len(expected_shape))
    for output, wanted in zip(outputs, expected_shape):
        self.assertEqual(output.shape, wanted)
def test_shape(self, input_param, input_shape, expected_shape):
    """Check the eval-mode output shape and, when requested, AlphaDropout use.

    The AlphaDropout check runs only when ``input_param["dropout"]`` is a
    string/tuple/list that contains ``"alphadropout"``. A missing key (``None``)
    or a numeric probability is skipped instead of raising ``TypeError``.
    """
    net = DynUNet(**input_param).to(device)
    dropout = input_param.get("dropout")
    # `in` is only meaningful on containers/strings; None or a float would raise.
    if isinstance(dropout, (str, tuple, list)) and "alphadropout" in dropout:
        self.assertTrue(any(isinstance(x, torch.nn.AlphaDropout) for x in net.modules()))
    with eval_mode(net):
        result = net(torch.randn(input_shape).to(device))
        self.assertEqual(result.shape, expected_shape)
def build_nnunet(self):
    """Create ``self.model`` as a DynUNet configured from ``self.args``.

    Channel counts, kernels, strides and ``self.patch_size`` come from
    ``self.get_unet_params()``. ``self.n_class`` is set to
    ``out_channels - 1`` (presumably excluding background — confirm) before
    the BraTS override, which forces the head to 3 output channels.
    """
    (in_channels, out_channels, kernels, strides,
     self.patch_size) = self.get_unet_params()
    # Derived from the get_unet_params value, i.e. before the BraTS override.
    self.n_class = out_channels - 1
    if self.args.brats:
        out_channels = 3

    args = self.args
    norm = ("INSTANCE", {"affine": True})
    activation = ("leakyrelu", {"inplace": True, "negative_slope": 0.01})
    self.model = DynUNet(
        args.dim,
        in_channels,
        out_channels,
        kernels,
        strides,
        strides[1:],  # upsample kernel sizes mirror the encoder strides
        filters=args.filters,
        norm_name=norm,
        act_name=activation,
        deep_supervision=args.deep_supervision,
        deep_supr_num=args.deep_supr_num,
        res_block=args.res_block,
        trans_bias=True,
    )
    print0(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
def test_shape(self, input_param, input_data, expected_shape):
    """Forward ``input_data`` once and compare each returned output's shape."""
    model = DynUNet(**input_param)
    with torch.no_grad():
        outputs = model(input_data)
        self.assertEqual(len(outputs), len(expected_shape))
        for output, wanted in zip(outputs, expected_shape):
            self.assertEqual(output.shape, wanted)
def test_consistency(self, input_param, input_shape, _):
    """Check that "instance" and "instance_nvfuser" norms give matching outputs.

    For every combination of norm ``eps``/``momentum``/``affine`` and memory
    format, two networks that differ only in ``norm_name`` are built on
    ``cuda:0``. The nvfuser one receives the other's weights, and both are
    evaluated on the same random input.
    """
    # assert_allclose is deprecated since torch 1.12; pick the helper once.
    if pytorch_after(1, 12):
        assert_close = torch.testing.assert_close
    else:
        assert_close = torch.testing.assert_allclose

    for eps in [1e-4, 1e-5]:
        for momentum in [0.1, 0.01]:
            for affine in [True, False]:
                norm_param = {"eps": eps, "momentum": momentum, "affine": affine}
                # Build fresh dicts instead of writing into ``input_param``: it is the
                # shared parameterized test case, and mutating it would leak the last
                # norm_name into any other test that reuses the same case.
                input_param_ref = {**input_param, "norm_name": ("instance", norm_param)}
                input_param_fuser = {**input_param, "norm_name": ("instance_nvfuser", norm_param)}
                for memory_format in [torch.contiguous_format, torch.channels_last_3d]:
                    net = DynUNet(**input_param_ref).to("cuda:0", memory_format=memory_format)
                    net_fuser = DynUNet(**input_param_fuser).to("cuda:0", memory_format=memory_format)
                    # identical weights, so any difference comes from the norm implementation
                    net_fuser.load_state_dict(net.state_dict())
                    input_tensor = torch.randn(input_shape).to("cuda:0", memory_format=memory_format)
                    with eval_mode(net):
                        result = net(input_tensor)
                    with eval_mode(net_fuser):
                        result_fuser = net_fuser(input_tensor)
                    assert_close(result, result_fuser)
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.

    Validation runs every ``validation_every_n_epochs`` epochs. This is converted to iterations using the number of
    batches the training DataLoader actually yields per epoch.

    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.

    Raises:
        FileNotFoundError: if ``training.model_to_load`` is set but does not exist.
        Exception: if ``training.batch_size_valid`` is not 1.
    """

    # ---- Read input and configuration parameters ----

    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    seg_labels = config_info['training'].get('seg_labels', [1])
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    model_to_load = config_info['training'].get('model_to_load')
    if model_to_load is not None:
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        print("Loading model from {}".format(model_to_load))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
    print("Using device: {}".format(current_device))

    # set determinism if required
    seed = config_info['training'].get('manual_seed')
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    # ---- Setup data output directory ----
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output']:
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # ---- Data preparation ----
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            # NOTE(review): MONAI rotation ranges are in radians, so 90 spans many full turns
            # (effectively a uniformly random angle) — confirm this was intended rather than degrees.
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # ---- Load data ----

    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch size different from 1 at validation are currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    # ---- Network preparation ----
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler (polynomial decay with power 0.9)
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    # ---- MONAI evaluator ----
    print("*** Preparing the dynUNet evaluator engine...\n")
    # `trainer` is referenced lazily inside the lambdas below; it is bound later in this function.
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        """Evaluator that averages predictions over the input and its three in-plane flips."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # test-time augmentation: predict on each flipped input, flip the prediction
                # back, and average with the un-flipped prediction
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    # ---- MONAI trainer ----
    print("*** Preparing the dynUNet trainer engine...\n")

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    # Use the loader's own length: without drop_last it yields ceil(len(ds) / batch_size) batches per
    # epoch. Flooring len(ds) // batch_size would drift off epoch boundaries, and would be 0 when the
    # dataset is smaller than one batch.
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        """Trainer applying a deep-supervision loss weighted 0.5 ** head_index."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # resize the label to each deep-supervision output's spatial size
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum(0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels)))

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    # ---- Run training ----
    print("*** Run training...")
    trainer.run()
    print("Done!")
def test_shape(self, input_param, input_data, expected_shape):
    """Forward ``input_data`` in eval mode without autograd and check the shape."""
    model = DynUNet(**input_param)
    model.eval()
    with torch.no_grad():
        output = model(input_data)
    self.assertEqual(output.shape, expected_shape)
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
    """Configure the single-label (spleen) DeepEdit setup.

    Sets labels, model file paths, the optional pretrained-weights download,
    target spacing / training input size, and ``self.network``. ``conf["network"]``
    equal to ``"unetr"`` selects UNETR; any other value (default ``"dynunet"``)
    builds a DynUNet.
    """
    super().init(name, model_dir, conf, planner, **kwargs)

    # Single-label setup. A multi-label alternative would be:
    #   {"spleen": 1, "right kidney": 2, "left kidney": 3, "liver": 6, "stomach": 7,
    #    "aorta": 8, "inferior vena cava": 9, "background": 0}
    self.labels = {"spleen": 1, "background": 0}

    # Number of input channels - 4 for BRATS and 1 for spleen
    self.number_intensity_ch = 1

    network = self.conf.get("network", "dynunet")

    # Model files: index 0 = pretrained weights, index 1 = published weights
    pretrained_file = os.path.join(self.model_dir, f"pretrained_{self.name}_{network}.pt")
    published_file = os.path.join(self.model_dir, f"{self.name}_{network}.pt")
    self.path = [pretrained_file, published_file]

    # Download PreTrained Model
    if strtobool(self.conf.get("use_pretrained_model", "true")):
        download_file(f"{self.PRE_TRAINED_PATH}/deepedit_{network}_singlelabel.pt", self.path[0])

    self.target_spacing = (1.0, 1.0, 1.0)  # target space for image
    self.spatial_size = (128, 128, 128)  # train input size

    # Input = one channel per label (guidance) plus the image intensity channels
    in_channels = len(self.labels) + self.number_intensity_ch
    out_channels = len(self.labels)

    if network == "unetr":
        self.network = UNETR(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=out_channels,
            img_size=self.spatial_size,
            feature_size=64,
            hidden_size=1536,
            mlp_dim=3072,
            num_heads=48,
            pos_embed="conv",
            norm_name="instance",
            res_block=True,
        )
    else:
        # six levels: 3x3x3 kernels everywhere; the last level downsamples in-plane only
        self.network = DynUNet(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=[[3, 3, 3] for _ in range(6)],
            strides=[[1, 1, 1]] + [[2, 2, 2] for _ in range(4)] + [[2, 2, 1]],
            upsample_kernel_size=[[2, 2, 2] for _ in range(4)] + [[2, 2, 1]],
            norm_name="instance",
            deep_supervision=False,
            res_block=True,
        )
def test_script(self):
    """TorchScript save/load round-trip on the first 2D DynUNet test case."""
    input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0]
    test_script_save(DynUNet(**input_param), torch.randn(input_shape))
def test_shape(self, input_param, input_shape, expected_shape):
    """Check the output shape of a forward pass run under ``eval_mode``."""
    model = DynUNet(**input_param)
    model = model.to(device)
    with eval_mode(model):
        batch = torch.randn(input_shape).to(device)
        prediction = model(batch)
    self.assertEqual(prediction.shape, expected_shape)
def test_shape(self, input_param, input_shape, expected_shape):
    """Check the output shape of a gradient-free forward pass."""
    model = DynUNet(**input_param)
    model = model.to(device)
    with torch.no_grad():
        batch = torch.randn(input_shape).to(device)
        prediction = model(batch)
    self.assertEqual(prediction.shape, expected_shape)
def run_inference(input_data, config_info):
    """
    Pipeline to run inference with MONAI dynUNet model. The pipeline reads the input filenames, applies the required
    preprocessing and creates the pytorch dataloader. It then performs evaluation on each input file using a trained
    dynUNet model. Flip-based test-time augmentation is applied at inference: predictions on the input and on its
    three in-plane flips are averaged.
    It uses the dynUNet model implemented in the MONAI framework
    (https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/nets/dynunet.py)
    which is inspired by the nnU-Net framework (https://arxiv.org/abs/1809.10486).
    Inference is performed in 2D slice-by-slice, and all slices are then recombined into the 3D volume.

    Post-processing: sigmoid + probability threshold for a single output channel, or softmax + argmax (then
    thresholded) for multiple channels. This is followed by keeping the largest connected component of label 1.

    Args:
        input_data: str or list of strings, filenames of images to be processed
        config_info: dict, contains the configuration parameters to reload the trained model

    Raises:
        FileNotFoundError: if ``inference.model_to_load`` does not exist.
        Exception: if ``inference.nr_out_channels`` is smaller than 1.
    """

    # ---- Read input and configuration parameters ----

    val_files = create_data_list_of_dictionaries(input_data)

    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print("*** MONAI config: ")
    print_config()

    # print to log the parameter setups
    print("*** Network inference config: ")
    print(yaml.dump(config_info))

    # inference params
    nr_out_channels = config_info['inference']['nr_out_channels']
    spacing = config_info["inference"]["spacing"]
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        raise FileNotFoundError('Trained model not found')
    patch_size = config_info["inference"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}".format(
            torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
    print("Using device: {}".format(current_device))

    # ---- Data Preparation ----

    print("*** Preparing data ... ")
    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - resample to the training resolution in-plane (not along z)
    # - apply whitening
    # - convert to tensor
    val_transforms = Compose([
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        InPlaneSpacingd(
            keys=["image"],
            pixdim=spacing,
            mode="bilinear",
        ),
        NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
        ToTensord(keys=["image"]),
    ])
    # create a validation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=config_info['device']['num_workers'])

    def prepare_batch(batchdata):
        # Split a batch dict into (image, label); label is None when absent (pure inference).
        if not isinstance(batchdata, dict):
            raise TypeError("prepare_batch expects dictionary input data.")
        return ((batchdata["image"], batchdata["label"])
                if "label" in batchdata else (batchdata["image"], None))

    # ---- Network preparation ----
    print("*** Preparing network ... ")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    net = DynUNet(spatial_dims=2,
                  in_channels=1,
                  out_channels=nr_out_channels,
                  kernel_size=kernels,
                  strides=strides,
                  upsample_kernel_size=strides[1:],
                  norm_name="instance",
                  deep_supervision=True,
                  deep_supr_num=2,
                  res_block=False).to(current_device)

    # ---- Set ignite evaluator to perform inference ----
    print("*** Preparing evaluator ... ")
    if nr_out_channels == 1:
        do_sigmoid = True
        do_softmax = False
    elif nr_out_channels > 1:
        do_sigmoid = False
        do_softmax = True
    else:
        raise Exception("incompatible number of output channels")
    print("Using sigmoid={} and softmax={} as final activation".format(
        do_sigmoid, do_softmax))
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=do_sigmoid, softmax=do_softmax),
        # argmax is only meaningful across several channels. On the single sigmoid channel it is
        # always 0, so the probability would never reach the threshold; there we threshold the
        # sigmoid output directly. The softmax (multi-channel) path is unchanged.
        AsDiscreted(keys="pred", argmax=do_softmax, threshold_values=True, logit_thresh=prob_thr),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=1)
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_to_load,
                         load_dict={"net": net},
                         map_location=torch.device('cpu')),
        SegmentationSaver(
            output_dir=config_info['output']['out_dir'],
            output_ext='.nii.gz',
            output_postfix=config_info['output']['out_postfix'],
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        """Evaluator averaging predictions over the input and its three in-plane flips."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs = inputs.to(engine.state.device)
            if targets is not None:
                targets = targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2, ))
            flip_inputs_2 = torch.flip(inputs, dims=(3, ))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # test-time augmentation: predict on each flipped input, flip the prediction
                # back, and average with the un-flipped prediction
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2, ))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3, ))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        prepare_batch=prepare_batch,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=val_post_transforms,
        val_handlers=val_handlers,
        amp=False,
    )

    # ---- Run inference ----
    print("*** Running evaluator ... ")
    evaluator.run()
    print("Done!")