예제 #1
0
def get_network(properties, task_id, pretrain_path, checkpoint=None):
    """Build a 3D DynUNet for a task and optionally restore saved weights.

    Args:
        properties: mapping whose ``"labels"`` and ``"modality"`` entries are
            counted to obtain the number of output and input channels.
        task_id: key handed to ``get_kernels_strides`` and used to index
            ``deep_supr_num``.
        pretrain_path: directory expected to contain the checkpoint file.
        checkpoint: file name inside ``pretrain_path``; ``None`` skips loading.

    Returns:
        The constructed network. Weights are restored only when ``checkpoint``
        is given and the file exists; otherwise a message is printed and the
        freshly initialised network is returned.
    """
    n_class = len(properties["labels"])
    in_channels = len(properties["modality"])
    kernels, strides = get_kernels_strides(task_id)

    net = DynUNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=n_class,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=deep_supr_num[task_id],
    )

    if checkpoint is None:
        return net

    checkpoint_path = os.path.join(pretrain_path, checkpoint)
    if os.path.exists(checkpoint_path):
        # Deserialize onto the CPU so that a checkpoint written from a GPU
        # process can also be restored on a host without CUDA;
        # load_state_dict then copies the values into the existing parameters.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        net.load_state_dict(state_dict)
        print("pretrained checkpoint: {} loaded".format(checkpoint_path))
    else:
        print("no pretrained checkpoint")
    return net
예제 #2
0
 def test_shape(self, input_param, input_shape, expected_shape):
     """Compare the network output and its feature maps with the expected shapes.

     The forward result is placed in front of whatever ``get_feature_maps()``
     returns, and the combined list is checked entry by entry.
     """
     model = DynUNet(**input_param).to(device)
     with torch.no_grad():
         sample = torch.randn(input_shape).to(device)
         outputs = [model(sample)] + model.get_feature_maps()
         self.assertEqual(len(outputs), len(expected_shape))
         # lengths are equal at this point, so zip pairs every entry
         for output, wanted_shape in zip(outputs, expected_shape):
             self.assertEqual(output.shape, wanted_shape)
예제 #3
0
 def test_shape(self, input_param, input_shape, expected_shape):
     """Check the output shape and, when requested, the dropout layer type.

     If the ``"dropout"`` entry of ``input_param`` mentions ``"alphadropout"``,
     the network must contain at least one ``torch.nn.AlphaDropout`` module.
     """
     net = DynUNet(**input_param).to(device)
     dropout = input_param.get("dropout")
     # ``.get`` yields None when the key is absent, and the value may also be
     # a plain number; the ``in`` test is only meaningful (and only safe) on
     # strings and sequences.
     if isinstance(dropout, (str, tuple, list)) and "alphadropout" in dropout:
         self.assertTrue(
             any(
                 isinstance(x, torch.nn.AlphaDropout)
                 for x in net.modules()))
     with eval_mode(net):
         result = net(torch.randn(input_shape).to(device))
         self.assertEqual(result.shape, expected_shape)
예제 #4
0
    def build_nnunet(self):
        """Create the DynUNet model from ``get_unet_params()`` and ``self.args``.

        Stores the patch size in ``self.patch_size``, the channel count minus
        one in ``self.n_class`` and the network in ``self.model``, then prints
        the filters, kernels and strides through ``print0``.
        """
        unet_params = self.get_unet_params()
        in_channels, out_channels, kernels, strides, self.patch_size = unet_params

        # n_class is derived before the BraTS override below replaces out_channels
        self.n_class = out_channels - 1
        if self.args.brats:
            out_channels = 3

        norm_config = ("INSTANCE", {"affine": True})
        activation_config = ("leakyrelu", {"inplace": True, "negative_slope": 0.01})

        self.model = DynUNet(
            self.args.dim,
            in_channels,
            out_channels,
            kernels,
            strides,
            strides[1:],
            filters=self.args.filters,
            norm_name=norm_config,
            act_name=activation_config,
            deep_supervision=self.args.deep_supervision,
            deep_supr_num=self.args.deep_supr_num,
            res_block=self.args.res_block,
            trans_bias=True,
        )

        summary = f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}"
        print0(summary)
예제 #5
0
 def test_shape(self, input_param, input_data, expected_shape):
     """Check that every entry of the network output has its expected shape."""
     model = DynUNet(**input_param)
     with torch.no_grad():
         outputs = model(input_data)
         self.assertEqual(len(outputs), len(expected_shape))
         # index into the output so each position is compared with its own shape
         for position, wanted_shape in enumerate(expected_shape):
             self.assertEqual(outputs[position].shape, wanted_shape)
예제 #6
0
    def test_consistency(self, input_param, input_shape, _):
        """Check that "instance" and "instance_nvfuser" norms give matching outputs.

        For every combination of ``eps``, ``momentum`` and ``affine`` and for
        both memory formats, two networks that differ only in ``norm_name`` are
        built on ``cuda:0``, given identical weights, and run in eval mode on
        the same random input; their results must be close.
        """
        norm_params = [
            {"eps": eps, "momentum": momentum, "affine": affine}
            for eps in [1e-4, 1e-5]
            for momentum in [0.1, 0.01]
            for affine in [True, False]
        ]
        memory_formats = [torch.contiguous_format, torch.channels_last_3d]

        for norm_param in norm_params:
            # Build both parameter sets as new dicts: writing "norm_name" into
            # ``input_param`` itself would leak the change to anything else
            # holding a reference to that dict.
            param_reference = {**input_param, "norm_name": ("instance", norm_param)}
            param_fuser = {**input_param, "norm_name": ("instance_nvfuser", norm_param)}

            for memory_format in memory_formats:
                net = DynUNet(**param_reference).to(
                    "cuda:0", memory_format=memory_format)
                net_fuser = DynUNet(**param_fuser).to(
                    "cuda:0", memory_format=memory_format)
                # give both networks the same weights so only the norm differs
                net_fuser.load_state_dict(net.state_dict())

                input_tensor = torch.randn(input_shape).to(
                    "cuda:0", memory_format=memory_format)
                with eval_mode(net):
                    result = net(input_tensor)
                with eval_mode(net_fuser):
                    result_fuser = net_fuser(input_tensor)

                # torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in 1.14
                if pytorch_after(1, 12):
                    torch.testing.assert_close(result, result_fuser)
                else:
                    torch.testing.assert_allclose(result, result_fuser)
예제 #7
0
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    Raises:
        FileNotFoundError: if a model to reload is configured but the file does not exist.
        ValueError: if the configured validation batch size is not 1.
    """

    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    seg_labels = config_info['training'].get('seg_labels', [1])
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    # the in-plane size is extended with a singleton third dimension
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    # optional checkpoint to resume from; a missing file is an error, an unset key is not
    model_to_load = config_info['training'].get('model_to_load')
    if model_to_load is not None:
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        print("Loading model from {}".format(model_to_load))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    seed = config_info['training'].get('manual_seed')
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    # ------------------------------------------------------------------
    # Setup data output directory
    # ------------------------------------------------------------------
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output']:
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # ------------------------------------------------------------------
    # Load data
    # ------------------------------------------------------------------
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise ValueError("Batch size different from 1 at validation is currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules:
    # an axis is halved (stride 2) while its spacing is at most twice the finest spacing and its
    # current size is at least 8; the loop stops as soon as no axis can be halved any more
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    # first level keeps the input resolution; last level always uses 3x3 kernels
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    # (the learning rate is scaled by (1 - epoch / nr_train_epochs) ** 0.9)
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    # ------------------------------------------------------------------
    # MONAI evaluator
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet evaluator engine...\n")
    # NOTE: the lambdas below refer to ``trainer``, which is only assigned further down;
    # this works because the name is looked up when the lambda is called, not when it is created
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        """Evaluator that averages predictions over the input and its three axis flips."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # flip-based augmentation at inference: predict on each flipped input, flip the
                # prediction back along the same dims, then average all four predictions
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # MONAI trainer
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet trainer engine...\n")

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    # Number of iterations in one epoch = number of batches the loader yields. The loader is
    # created without drop_last, so the final partial batch counts too; integer-dividing the
    # dataset size by the batch size would under-count it and would give 0 (an unusable
    # validation interval) whenever the dataset is smaller than one batch.
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        """Trainer whose loss is a weighted sum over all outputs returned by the network."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # the label is resized to the spatial shape of every prediction after the first;
                # the i-th prediction contributes with weight 0.5 ** i
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum(0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels)))

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    # reduce to a scalar exactly as in the non-AMP branch
                    loss = _compute_loss(predictions, targets).mean()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # Run training
    # ------------------------------------------------------------------
    print("*** Run training...")
    trainer.run()
    print("Done!")
예제 #8
0
 def test_shape(self, input_param, input_data, expected_shape):
     """Run the network in eval mode without gradients and check the output shape."""
     model = DynUNet(**input_param)
     model.eval()
     with torch.no_grad():
         output = model(input_data)
     self.assertEqual(output.shape, expected_shape)
예제 #9
0
    def init(self, name: str, model_dir: str, conf: Dict[str, str],
             planner: Any, **kwargs):
        """Configure labels, model file paths, spacing and the segmentation network.

        After delegating to the parent ``init``, this sets the label map, the
        candidate model files (pretrained first, then published), optionally
        downloads the pretrained weights, and builds either a ``UNETR`` or a
        ``DynUNet`` depending on the ``"network"`` entry of ``self.conf``.
        """
        super().init(name, model_dir, conf, planner, **kwargs)

        # Single label setup. A multilabel alternative would be:
        #     {"spleen": 1, "right kidney": 2, "left kidney": 3, "liver": 6,
        #      "stomach": 7, "aorta": 8, "inferior vena cava": 9, "background": 0}
        self.labels = {
            "spleen": 1,
            "background": 0,
        }

        # Number of input channels - 4 for BRATS and 1 for spleen
        self.number_intensity_ch = 1

        network = self.conf.get("network", "dynunet")

        # Model files: index 0 is the pretrained model, index 1 the published one
        pretrained_file = os.path.join(self.model_dir,
                                       f"pretrained_{self.name}_{network}.pt")
        published_file = os.path.join(self.model_dir,
                                      f"{self.name}_{network}.pt")
        self.path = [pretrained_file, published_file]

        # Fetch the pretrained weights unless the configuration disables it
        use_pretrained = strtobool(self.conf.get("use_pretrained_model", "true"))
        if use_pretrained:
            url = f"{self.PRE_TRAINED_PATH}/deepedit_{network}_singlelabel.pt"
            download_file(url, self.path[0])

        self.target_spacing = (1.0, 1.0, 1.0)  # target space for image
        self.spatial_size = (128, 128, 128)  # train input size

        # The network input stacks the intensity channels with one channel per label
        n_labels = len(self.labels)
        n_inputs = n_labels + self.number_intensity_ch

        if network == "unetr":
            self.network = UNETR(
                spatial_dims=3,
                in_channels=n_inputs,
                out_channels=n_labels,
                img_size=self.spatial_size,
                feature_size=64,
                hidden_size=1536,
                mlp_dim=3072,
                num_heads=48,
                pos_embed="conv",
                norm_name="instance",
                res_block=True,
            )
            return

        # six levels, all with 3x3x3 kernels
        kernel_sizes = [[3, 3, 3] for _ in range(6)]
        level_strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2],
                         [2, 2, 1]]
        # upsampling kernels repeat the strides of every level after the first
        upsample_kernels = [list(stride) for stride in level_strides[1:]]
        self.network = DynUNet(
            spatial_dims=3,
            in_channels=n_inputs,
            out_channels=n_labels,
            kernel_size=kernel_sizes,
            strides=level_strides,
            upsample_kernel_size=upsample_kernels,
            norm_name="instance",
            deep_supervision=False,
            res_block=True,
        )
예제 #10
0
 def test_script(self):
     """Pass a network built from the first 2D test case to ``test_script_save``."""
     params, shape, _expected = TEST_CASE_DYNUNET_2D[0]
     # the network is built before the random input, as arguments evaluate left to right
     test_script_save(DynUNet(**params), torch.randn(shape))
예제 #11
0
 def test_shape(self, input_param, input_shape, expected_shape):
     """Check the output shape of the network when run under ``eval_mode``."""
     model = DynUNet(**input_param).to(device)
     sample = torch.randn(input_shape).to(device)
     with eval_mode(model):
         output = model(sample)
     self.assertEqual(output.shape, expected_shape)
예제 #12
0
 def test_shape(self, input_param, input_shape, expected_shape):
     """Check the output shape of the network when run without gradients."""
     model = DynUNet(**input_param).to(device)
     sample = torch.randn(input_shape).to(device)
     with torch.no_grad():
         output = model(sample)
     self.assertEqual(output.shape, expected_shape)
예제 #13
0
def run_inference(input_data, config_info):
    """
    Pipeline to run inference with MONAI dynUNet model. The pipeline reads the input filenames, applies the required
    preprocessing and creates the pytorch dataloader; it then performs evaluation on each input file using a trained
    dynUNet model (flip-based augmentation is applied at inference: predictions on the input and on its three
    in-plane flips are averaged).
    It uses the dynUNet model implemented in the MONAI framework
    (https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/nets/dynunet.py)
    which is inspired by the nnU-Net framework (https://arxiv.org/abs/1809.10486)
    Inference is performed in 2D slice-by-slice, all slices are then recombined together into the 3D volume.

    Args:
        input_data: str or list of strings, filenames of images to be processed
        config_info: dict, contains the configuration parameters to reload the trained model

    Raises:
        FileNotFoundError: if the configured trained model file does not exist.
        ValueError: if the configured number of output channels is smaller than 1.
    """
    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    val_files = create_data_list_of_dictionaries(input_data)

    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print("*** MONAI config: ")
    print_config()

    # print to log the parameter setups
    print("*** Network inference config: ")
    print(yaml.dump(config_info))

    # inference params
    nr_out_channels = config_info['inference']['nr_out_channels']
    spacing = config_info["inference"]["spacing"]
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        raise FileNotFoundError('Trained model not found')
    # the in-plane size is extended with a singleton third dimension
    patch_size = config_info["inference"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}".format(
            torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # ------------------------------------------------------------------
    # Data Preparation
    # ------------------------------------------------------------------
    print("***  Preparing data ... ")
    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - resample to the training resolution in-plane (not along z)
    # - apply whitening
    # - convert to tensor
    val_transforms = Compose([
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        InPlaneSpacingd(
            keys=["image"],
            pixdim=spacing,
            mode="bilinear",
        ),
        NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
        ToTensord(keys=["image"]),
    ])
    # create a validation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=config_info['device']['num_workers'])

    def prepare_batch(batchdata):
        """Return ``(image, label)`` from a batch dict; label is None when absent."""
        assert isinstance(batchdata,
                          dict), "prepare_batch expects dictionary input data."
        return ((batchdata["image"],
                 batchdata["label"]) if "label" in batchdata else
                (batchdata["image"], None))

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    print("***  Preparing network ... ")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules:
    # an axis is halved (stride 2) while its spacing is at most twice the finest spacing and its
    # current size is at least 8; the loop stops as soon as no axis can be halved any more
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    # first level keeps the input resolution; last level always uses 3x3 kernels
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    net = DynUNet(spatial_dims=2,
                  in_channels=1,
                  out_channels=nr_out_channels,
                  kernel_size=kernels,
                  strides=strides,
                  upsample_kernel_size=strides[1:],
                  norm_name="instance",
                  deep_supervision=True,
                  deep_supr_num=2,
                  res_block=False).to(current_device)

    # ------------------------------------------------------------------
    # Set ignite evaluator to perform inference
    # ------------------------------------------------------------------
    print("***  Preparing evaluator ... ")
    if nr_out_channels < 1:
        raise ValueError("incompatible number of output channels")
    # a single output channel is activated with sigmoid, several channels with softmax
    do_sigmoid = nr_out_channels == 1
    do_softmax = not do_sigmoid
    print("Using sigmoid={} and softmax={} as final activation".format(
        do_sigmoid, do_softmax))
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=do_sigmoid, softmax=do_softmax),
        # Argmax is only meaningful across several channels: over a single sigmoid channel it
        # is always 0, which would turn every prediction into an empty mask before the
        # threshold is applied. With one channel the probabilities are thresholded directly.
        AsDiscreted(keys="pred",
                    argmax=do_softmax,
                    threshold_values=True,
                    logit_thresh=prob_thr),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=1)
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_to_load,
                         load_dict={"net": net},
                         map_location=torch.device('cpu')),
        SegmentationSaver(
            output_dir=config_info['output']['out_dir'],
            output_ext='.nii.gz',
            output_postfix=config_info['output']['out_postfix'],
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        """Evaluator that averages predictions over the input and its three axis flips."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs = inputs.to(engine.state.device)
            # labels are optional at inference time
            if targets is not None:
                targets = targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2, ))
            flip_inputs_2 = torch.flip(inputs, dims=(3, ))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # flip-based augmentation at inference: predict on each flipped input, flip the
                # prediction back along the same dims, then average all four predictions
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1,
                                                      self.network),
                                         dims=(2, ))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2,
                                                      self.network),
                                         dims=(3, ))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3,
                                                      self.network),
                                         dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        prepare_batch=prepare_batch,
        inferer=SlidingWindowInferer2D(roi_size=patch_size,
                                       sw_batch_size=4,
                                       overlap=0.0),
        post_transform=val_post_transforms,
        val_handlers=val_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # Run inference
    # ------------------------------------------------------------------
    print("***  Running evaluator ... ")
    evaluator.run()
    print("Done!")