Example #1
0
 def post_transforms(self, data=None) -> Sequence[Callable]:
     """Return the post-processing transforms applied to the ``pred`` entry.

     ``pred`` is moved to the device found under ``data["device"]`` (when ``data`` is truthy),
     passed through softmax and argmax, squeezed along dim 0, converted to numpy and finally
     handed to ``Restored`` with ``image`` as the reference key.
     """
     device = None
     if data:
         device = data.get("device")
     chain = [
         EnsureTyped(keys="pred", device=device),
         Activationsd(keys="pred", softmax=True),
         AsDiscreted(keys="pred", argmax=True),
         SqueezeDimd(keys="pred", dim=0),
         ToNumpyd(keys="pred"),
         Restored(keys="pred", ref_image="image"),
     ]
     return chain
 def post_transforms(self, data=None) -> Sequence[Callable]:
     """Return the post-processing transforms applied to the ``pred`` entry.

     With exactly one label the prediction goes through a sigmoid and a 0.5 threshold; with more
     than one label it goes through softmax and argmax. Afterwards ``pred`` is squeezed along
     dim 0, ``image`` and ``pred`` are converted to numpy, and ``PostFilterLabeld`` /
     ``FindContoursd`` are applied using ``self.labels``.
     """
     device = data.get("device") if data else None
     label_count = len(self.labels)
     is_binary = label_count == 1
     is_multiclass = label_count > 1
     threshold = None
     if is_binary:
         threshold = 0.5
     return [
         EnsureTyped(keys="pred", device=device),
         Activationsd(keys="pred", softmax=is_multiclass, sigmoid=is_binary),
         AsDiscreted(keys="pred", argmax=is_multiclass, threshold=threshold),
         SqueezeDimd(keys="pred", dim=0),
         ToNumpyd(keys=("image", "pred")),
         PostFilterLabeld(keys="pred", image="image"),
         FindContoursd(keys="pred", labels=self.labels),
     ]
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    Raises:
        FileNotFoundError: if config_info['training']['model_to_load'] is set but the file does not exist.
        ValueError: if config_info['training']['batch_size_valid'] is not 1 (only batch size 1 is supported
            at validation).
    """

    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    training_cfg = config_info['training']
    output_cfg = config_info['output']

    # extract network parameters, perform checks/set defaults if not present and print them to log
    seg_labels = training_cfg.get('seg_labels', [1])
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    patch_size = training_cfg["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = training_cfg["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    # a missing or null 'model_to_load' means training from scratch
    model_to_load = training_cfg.get('model_to_load')
    if model_to_load is not None:
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        print("Loading model from {}".format(model_to_load))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required (a missing or null 'manual_seed' disables it)
    seed = training_cfg.get('manual_seed')
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    # ------------------------------------------------------------------
    # Setup data output directory
    # ------------------------------------------------------------------
    out_model_dir = os.path.join(output_cfg['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 output_cfg['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset;
    # a missing or null 'cache_dir' falls back to a folder inside the model output directory
    out_cache_dir = output_cfg.get('cache_dir')
    if out_cache_dir is None:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # ------------------------------------------------------------------
    # Load data
    # ------------------------------------------------------------------
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=training_cfg['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if training_cfg['batch_size_valid'] != 1:
        raise ValueError("Batch sizes different from 1 are currently not supported at validation")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler (polynomial decay with power 0.9)
    opt = torch.optim.SGD(net.parameters(), lr=float(training_cfg['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / training_cfg['nr_train_epochs']) ** 0.9
    )

    # ------------------------------------------------------------------
    # MONAI evaluator
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet evaluator engine...\n")
    # NOTE: the lambdas below reference `trainer`, which is created further down; they are only
    # called while the trainer is running, so the name is bound by then.
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    # NOTE: the config key is spelled 'val_image_to_tensorboad' (sic); kept so existing configs keep working
    if output_cfg['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator: averages the predictions on the input and its three in-plane flips
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use flipping as test-time augmentation: flip back each prediction before averaging
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # MONAI trainer
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet trainer engine...\n")

    validation_every_n_epochs = training_cfg['validation_every_n_epochs']
    # Number of iterations per epoch. The DataLoader keeps the last incomplete batch (drop_last=False),
    # so this is ceil(len(train_ds) / batch_size). Floor division would under-count it, drifting validation
    # away from epoch boundaries (and giving an interval of 0 when the dataset is smaller than a batch).
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=output_cfg['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer: deep-supervision loss, with the i-th output weighted by 0.5 ** i
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # resize the ground truth to the spatial size of each auxiliary output
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    # reduce to a scalar exactly like the non-AMP branch below
                    loss = _compute_loss(predictions, targets).mean()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=training_cfg['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # Run training
    # ------------------------------------------------------------------
    print("*** Run training...")
    trainer.run()
    print("Done!")
Example #4
0
 def test_invalid_inputs(self, exception, input_param, test_data):
     """Building SqueezeDimd from ``input_param`` and applying it to ``test_data`` must raise ``exception``."""
     def _squeeze():
         # construction happens inside the checked call, so constructor errors also count
         return SqueezeDimd(**input_param)(test_data)

     self.assertRaises(exception, _squeeze)
Example #5
0
 def test_shape(self, input_param, test_data, expected_shape):
     """After SqueezeDimd with ``input_param``, both "img" and "seg" must have ``expected_shape``."""
     squeezed = SqueezeDimd(**input_param)(test_data)
     for key in ("img", "seg"):
         self.assertTupleEqual(squeezed[key].shape, expected_shape)
Example #6
0
 def test_invalid_inputs(self, input_param, test_data):
     """Building SqueezeDimd from ``input_param`` and applying it to ``test_data`` must raise AssertionError."""
     with self.assertRaises(AssertionError):
         # the result is irrelevant: only the raised exception is checked
         SqueezeDimd(**input_param)(test_data)
def main():
    """
    Train a basic 2D MONAI UNet on 2D slices of 3D NIfTI volumes with a plain PyTorch loop.

    All parameters are read from the YAML file passed with ``--config``. Training uses random
    [96, 96, 1] patches squeezed to 2D. Validation either runs a sliding window over the resized
    volume (``sliding_window_validation`` true) or scores random 2D slices as ``1 - DiceLoss``.
    The best model by validation metric is saved as ``best_metric_model.pth`` in the timestamped
    output directory, and losses, metrics and images are logged to TensorBoard under its
    ``train`` and ``valid`` subfolders.
    """
    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Run basic UNet with MONAI.')
    parser.add_argument('--config', dest='config', metavar='config', type=str,
                        help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    # NOTE(review): loss_type is currently unused below (DiceLoss is hard-coded)
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    sliding_window_validation = config_info['training']['sliding_window_validation']
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving
    out_model_dir = os.path.join(config_info['output']['out_model_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    max_nr_models_saved = config_info['output']['max_nr_models_saved']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)

    # ------------------------------------------------------------------
    # Data Preparation
    # ------------------------------------------------------------------
    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')

    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1], interp_order=[1, 0], reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True, num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_train_data = monai.utils.misc.first(train_loader)
    print("Training data tensor shapes")
    print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    check_valid_data = monai.utils.misc.first(val_loader)
    print("Validation data tensor shapes")
    print(check_valid_data['img'].shape, check_valid_data['seg'].shape)

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Number of batches per epoch. The training DataLoader keeps the last incomplete batch
    # (drop_last=False), so this is ceil(len(train_ds) / batch_size). Floor division would
    # under-count it and make the global TensorBoard steps of consecutive epochs overlap.
    epoch_len = len(train_loader)
    # sliding window ROI used when validating on whole (resized) volumes
    roi_size = (96, 96, 1)
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    net.to(device)
    try:
        for epoch in range(nr_train_epochs):
            print('-' * 10)
            print('Epoch {}/{}'.format(epoch + 1, nr_train_epochs))
            net.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
                opt.zero_grad()
                outputs = net(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                opt.step()
                loss_value = loss.item()
                epoch_loss += loss_value
                print("%d/%d, train_loss:%0.4f" % (step, epoch_len, loss_value))
                writer_train.add_scalar('loss', loss_value, epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

            if (epoch + 1) % validation_every_n_epochs == 0:
                net.eval()
                with torch.no_grad():
                    metric_sum = 0.
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    check_tot_validation = 0
                    if sliding_window_validation:
                        print('Running sliding window validation')
                    else:
                        print('Running 2D validation')
                    for val_data in val_loader:
                        check_tot_validation += 1
                        val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                        if sliding_window_validation:
                            val_outputs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
                            value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                     to_onehot_y=False, add_sigmoid=True)
                            metric_count += len(value)
                            metric_sum += value.sum().item()
                        else:
                            # compute validation on a 2D slice, scored as 1 - Dice loss
                            val_outputs = net(val_images)
                            value = 1.0 - loss_function(val_outputs, val_labels)
                            metric_count += 1
                            metric_sum += value.item()
                    print("Total number of data in validation: %d" % check_tot_validation)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(net.state_dict(), os.path.join(out_model_dir, 'best_metric_model.pth'))
                        print('saved new best metric model')
                    print("current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                          % (epoch + 1, metric, best_metric, best_metric_epoch))
                    writer_valid.add_scalar('loss', 1.0 - metric, epoch_len * epoch + step)
                    writer_valid.add_scalar('val_mean_dice', metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer_valid, index=0, tag='image')
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer_valid, index=0, tag='label')
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer_valid, index=0, tag='output')

        print('train completed, best_metric: %0.4f  at epoch: %d' % (best_metric, best_metric_epoch))
    finally:
        # always flush/close the TensorBoard writers, even if training is interrupted by an error
        writer_train.close()
        writer_valid.close()
Example #8
0
def _preprocessing_transforms():
    """Return fresh instances of the preprocessing steps shared by every pipeline.

    Loads the NIfTI image/label pair, adds a channel dimension, whitens the
    image and resizes image and label to 96x96 in-plane (interpolation order 1
    with anti-aliasing for the image, order 0 without for the label); the
    z-direction is preserved.
    """
    return [
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'],
                spatial_size=[96, 96],
                interp_order=[1, 0],
                anti_aliasing=[True, False]),
    ]


def main():
    """
    Basic UNet as implemented in MONAI for Fetal Brain Segmentation, but using
    ignite to manage training and validation loop and checkpointing.

    Reads the YAML configuration given with ``--config``, builds the 2D
    training / validation pipelines, trains a 2D UNet with a Dice loss and
    runs validation every ``training.validation_every_n_epochs`` epochs.
    Checkpoints and TensorBoard logs are written to a timestamped sub-folder
    of ``output.out_model_dir``.

    :raises FileNotFoundError: if ``training.model_to_load`` is set but the
        file does not exist.
    :raises ValueError: if the training or validation data list is empty.
    """
    # ----- Read input and configuration parameters -----
    parser = argparse.ArgumentParser(
        description='Run basic UNet with MONAI - Ignite version.')
    parser.add_argument('--config',
                        dest='config',
                        metavar='config',
                        type=str,
                        help='config file')
    args = parser.parse_args()

    # NOTE(review): FullLoader is kept for compatibility with existing configs;
    # switch to yaml.safe_load if config files can come from untrusted sources.
    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    device_cfg = config_info['device']
    training_cfg = config_info['training']
    data_cfg = config_info['data']
    output_cfg = config_info['output']

    # GPU params
    cuda_device = device_cfg['cuda_device']
    num_workers = device_cfg['num_workers']
    # training and validation params
    # (the 'loss_type' config entry is currently ignored: DiceLoss is always used)
    batch_size_train = training_cfg['batch_size_train']
    batch_size_valid = training_cfg['batch_size_valid']
    lr = float(training_cfg['lr'])
    lr_decay = training_cfg['lr_decay']
    if lr_decay is not None:
        lr_decay = float(lr_decay)
    nr_train_epochs = training_cfg['nr_train_epochs']
    validation_every_n_epochs = training_cfg['validation_every_n_epochs']
    sliding_window_validation = training_cfg['sliding_window_validation']
    # optional checkpoint to resume from; a missing key or a null value means "from scratch"
    model_to_load = training_cfg.get('model_to_load')
    if model_to_load is not None and not os.path.exists(model_to_load):
        raise FileNotFoundError("cannot find model: {}".format(model_to_load))
    # optional manual seed for reproducibility
    seed = training_cfg.get('manual_seed')
    # data params
    data_root = data_cfg['data_root']
    training_list = data_cfg['training_list']
    validation_list = data_cfg['validation_list']
    # model saving
    out_model_dir = os.path.join(
        output_cfg['out_model_dir'],
        datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
        output_cfg['output_subfix'])
    print("Saving to directory ", out_model_dir)
    out_cache_dir = output_cfg.get('cache_dir')
    if out_cache_dir is None:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    max_nr_models_saved = output_cfg['max_nr_models_saved']
    val_image_to_tensorboad = output_cfg['val_image_to_tensorboad']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)
    if seed is not None:
        # set manual seed if required (both numpy and torch)
        set_determinism(seed=seed)

    # ----- Data preparation -----
    # create cache directory to store results for Persistent Dataset
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')
    if not train_files:
        raise ValueError(
            "no training data found for list: {}".format(training_list))
    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    if not val_files:
        raise ValueError(
            "no validation data found for list: {}".format(validation_list))
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training: shared preprocessing, then
    # - extract random 2D patches (96x96x1)
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose(_preprocessing_transforms() + [
        RandSpatialCropd(keys=['img', 'seg'],
                         roi_size=[96, 96, 1],
                         random_size=False),
        RandRotated(keys=['img', 'seg'],
                    degrees=90,
                    prob=0.2,
                    spatial_axes=[0, 1],
                    interp_order=[1, 0],
                    reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    train_ds = monai.data.PersistentDataset(data=train_files,
                                            transform=train_transforms,
                                            cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())

    # data preprocessing for validation
    if sliding_window_validation:
        # whole (resized) volumes; slices are handled by the sliding window evaluator
        val_transforms = Compose(_preprocessing_transforms() +
                                 [ToTensord(keys=['img', 'seg'])])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # extract 2D slices from the validation set to emulate how the loss is computed at training
        val_transforms = Compose(_preprocessing_transforms() + [
            RandSpatialCropd(keys=['img', 'seg'],
                             roi_size=[96, 96, 1],
                             random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    val_ds = monai.data.PersistentDataset(data=val_files,
                                          transform=val_transforms,
                                          cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)

    # ----- Network preparation -----
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()
    if lr_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt,
                                                              gamma=lr_decay,
                                                              last_epoch=-1)

    # ----- Ignite trainer -----
    # function to manage batch at training
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch['img'], batch['seg']), device,
                              non_blocking)

    trainer = create_supervised_trainer(model=net,
                                        optimizer=opt,
                                        loss_fn=loss_function,
                                        device=device,
                                        non_blocking=False,
                                        prepare_batch=prepare_batch)

    # optionally resume network params and optimizer stats from a checkpoint
    if model_to_load is not None:
        checkpoint_loader = CheckpointLoader(load_path=model_to_load,
                                             load_dict={
                                                 'net': net,
                                                 'opt': opt,
                                             })
        checkpoint_loader.attach(trainer)

    # save models (network params and optimizer stats) after every epoch;
    # attached unconditionally so that resumed runs are checkpointed too
    checkpoint_handler = ModelCheckpoint(out_model_dir,
                                         'net',
                                         n_saved=max_nr_models_saved,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'net': net,
                                  'opt': opt
                              })

    # StatsHandler prints loss at every iteration and print metrics at every epoch
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_train)
    train_tensorboard_stats_handler.attach(trainer)

    if lr_decay is not None:
        print("Using Exponential LR decay")
        lr_schedule_handler = LrScheduleHandler(lr_scheduler,
                                                print_lr=True,
                                                name="lr_scheduler",
                                                writer=writer_train)
        lr_schedule_handler.attach(trainer)

    # ----- Ignite evaluator to perform validation at training -----
    # add evaluation metrics to the evaluator engine
    val_metrics = {
        "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False),
        "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False)
    }

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch['img'].to(device), batch['seg'].to(
                device)
            # NOTE(review): the 2D UNet is handed (96, 96, 1) windows here; confirm the
            # installed sliding_window_inference drops the singleton axis for 2D nets.
            roi_size = (96, 96, 1)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 batch_size_valid, net)
            return seg_probs, val_labels

    if sliding_window_validation:
        # use sliding window inference at validation
        print("3D evaluator is used")
        net.to(device)
        evaluator = Engine(_sliding_window_processor)
        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)
    else:
        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        print("2D evaluator is used")
        evaluator = create_supervised_evaluator(model=net,
                                                metrics=val_metrics,
                                                device=device,
                                                non_blocking=True,
                                                prepare_batch=prepare_batch)

    # Trigger on epoch boundaries: deriving an iteration count from
    # len(dataset) // batch_size drifts when the batch size does not divide the
    # dataset size and becomes 0 (invalid) for datasets smaller than one batch.
    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_valid,
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.iteration
    )  # fetch global iteration number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output
    # of every validation batch
    if val_image_to_tensorboad:
        val_tensorboard_image_handler = TensorBoardImageHandler(
            summary_writer=writer_valid,
            batch_transform=lambda batch: (batch['img'], batch['seg']),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch)
        evaluator.add_event_handler(
            event_name=Events.ITERATION_COMPLETED(every=1),
            handler=val_tensorboard_image_handler)

    # ----- Run training -----
    try:
        trainer.run(train_loader, nr_train_epochs)
    finally:
        # flush and release the TensorBoard event files even if training fails
        writer_train.close()
        writer_valid.close()
    print("Done!")
Example #9
0
 def test_invalid_inputs(self, input_param, test_data):
     """Building SqueezeDimd from ``input_param`` and applying it to ``test_data`` must raise ValueError."""
     with self.assertRaises(ValueError):
         squeeze = SqueezeDimd(**input_param)
         squeeze(test_data)
Example #10
0
def image_mixing(data, seed=None):
    """Interactively create "mixed" CT volumes from pairs of positive samples.

    Every pair of entries in ``data`` whose ``'_label'`` equals 1 is
    preprocessed with the same transform chain. The pair is then combined by
    taking the voxel-wise maximum inside the intersection of their non-zero
    masks. Each result is shown with ``multi_slice_viewer``, and the user is
    prompted to save it ('y'), skip it ('n') or exit the program ('e').
    Saved volumes go to ``<data dir>/mixed_images/<blake2b hex>.nii.gz``.

    Args:
        data: iterable of dicts with at least ``'image'``, ``'seg'`` and
            ``'_label'`` keys (presumably file paths for image/seg; confirm
            against callers).
        seed: currently unused; the ``random.seed`` call below is commented
            out, so the shuffle order is not reproducible.
    """
    #random.seed(seed)

    # keep only positive samples (label 1) and shuffle their order
    file_list = [x for x in data if int(x['_label']) == 1]
    random.shuffle(file_list)

    # crop to the bounding box of non-zero voxels, with no margin
    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(0, 0, 0),
                                      select_fn=lambda x: x != 0)
    # CT intensity window: width 1500, level -600
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    # in-plane resize to int(512 * 0.75) = 384 voxels, depth unchanged (-1)
    resize2 = Resized(keys=["image"],
                      spatial_size=(int(512 * 0.75), int(512 * 0.75), -1),
                      mode="area")
    # resize the last axis to 40 slices, in-plane unchanged (-1)
    resize1 = Resized(keys=["image"],
                      spatial_size=(-1, -1, 40),
                      mode="nearest")
    # NOTE(review): gauss and gauss2 are created but never used below.
    gauss = GaussianSmooth(sigma=(1., 1., 0))
    gauss2 = GaussianSmooth(sigma=(2.0, 2.0, 0))
    affine = Affined(keys=["image"],
                     scale_params=(1.0, 2.0, 1.0),
                     padding_mode='zeros')

    # shared preprocessing applied to both images of each pair; only "image" is
    # transformed ('seg' is carried through untouched)
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),  # project-specific transform; semantics not visible here
        AddChanneld(keys=["image"]),
        affine,
        crop_foreground,
        resize1,
        resize2,
        SqueezeDimd(keys=["image"]),
    ])

    dirs = setup_directories()
    data_dir = dirs['data']
    mixed_images_dir = os.path.join(data_dir, 'mixed_images')
    # NOTE(review): os.mkdir is not recursive; this assumes data_dir already exists.
    if not os.path.exists(mixed_images_dir):
        os.mkdir(mixed_images_dir)

    # every unordered pair of positive samples is visited once
    for img1, img2 in itertools.combinations(file_list, 2):

        img1 = {'image': img1["image"], 'seg': img1['seg']}
        img2 = {'image': img2["image"], 'seg': img2['seg']}

        img1_data = common_transform(img1)["image"]
        img2_data = common_transform(img2)["image"]
        # combine only where both volumes are non-zero; take the brighter voxel there
        img1_mask, img2_mask = (img1_data > 0), (img2_data > 0)
        img_presek = np.logical_and(img1_mask, img2_mask)
        img = np.maximum(img_presek * img1_data, img_presek * img2_data)

        multi_slice_viewer(img, "img1")

        # prompt until a valid answer is given
        loop = True
        while loop:
            save = input("Save image [y/n/e]: ")
            if save.lower() == 'y':
                loop = False
                # file name: 16-byte blake2b digest keyed with the current timestamp
                k = str(time.time()).encode('utf-8')
                h = blake2b(key=k, digest_size=16)
                name = h.hexdigest() + '.nii.gz'
                out_path = os.path.join(mixed_images_dir, name)
                # NOTE(review): no affine is passed, so spatial metadata of the
                # inputs is presumably not preserved; confirm this is intended.
                write_nifti(img, out_path, resample=False)
            elif save.lower() == 'n':
                # skip this pair (the break alone would suffice)
                loop = False
                break
            elif save.lower() == 'e':
                # terminates the whole process, not just this function
                print("exeting")
                exit()
            else:
                print("wrong input!")