Exemple #1
0
def main():
    """Evaluate a 3D UNet on synthetic NIfTI volumes with sliding-window inference.

    Five random image/segmentation pairs are written to a temporary
    directory, wrapped in a ``NiftiDataset`` and pushed through an ignite
    ``Engine`` that runs ``sliding_window_inference`` on every batch.  The
    engine reports the mean Dice metric, saves the predicted segmentations
    and loads the network weights from ``./runs/net_checkpoint_100.pth``.

    The temporary input directory is always removed, even when data
    generation or evaluation raises.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        ds = NiftiDataset(images,
                          segs,
                          transform=imtrans,
                          seg_transform=segtrans,
                          image_only=False)

        # fall back to the CPU when CUDA is unavailable instead of crashing
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        net = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        net.to(device)

        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        def _sliding_window_processor(engine, batch):
            """Run one no-grad inference step and return (prediction, label)."""
            net.eval()
            with torch.no_grad():
                val_images = batch[0].to(device)
                val_labels = batch[1].to(device)
                seg_probs = sliding_window_inference(val_images, roi_size,
                                                     sw_batch_size, net)
                return seg_probs, val_labels

        evaluator = Engine(_sliding_window_processor)

        # add evaluation metric to the evaluator engine
        MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator,
                                                         "Mean_Dice")

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
        val_stats_handler = StatsHandler(
            name="evaluator",
            # no need to print loss value, so disable per iteration output
            output_transform=lambda x: None,
        )
        val_stats_handler.attach(evaluator)

        # for the array data format, assume the 3rd item of batch data is the meta_data
        # NOTE: output_dir is the literal relative directory "tempdir", not the
        # temporary input directory, so the results survive the cleanup below.
        file_saver = SegmentationSaver(
            output_dir="tempdir",
            output_ext=".nii.gz",
            output_postfix="seg",
            name="evaluator",
            batch_transform=lambda x: x[2],
            output_transform=lambda output: predict_segmentation(output[0]),
        )
        file_saver.attach(evaluator)

        # the model was trained by "unet_training_array" example
        ckpt_loader = CheckpointLoader(
            load_path="./runs/net_checkpoint_100.pth",
            load_dict={"net": net})
        ckpt_loader.attach(evaluator)

        # sliding window inference for one image at every iteration
        loader = DataLoader(ds,
                            batch_size=1,
                            num_workers=1,
                            pin_memory=torch.cuda.is_available())
        state = evaluator.run(loader)
        print(state)
    finally:
        # mkdtemp() leaves deletion to the caller; do it on every exit path
        shutil.rmtree(tempdir)
Exemple #2
0
def main():
    """Evaluate a 3D DenseNet121 classifier on ten IXI T1 volumes.

    The images are read from ``workspace/data/medical/ixi/IXI-T1`` (relative
    to the working directory), resized to 96x96x96 and classified with
    weights loaded from ``./runs_array/net_checkpoint_20.pth``.  Accuracy is
    reported through an ignite supervised evaluator and the predicted class
    per image is saved by ``ClassificationSaver`` into ``./tempdir``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    image_names = [
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [
        os.sep.join(
            ["workspace", "data", "medical", "ixi", "IXI-T1", file_name])
        for file_name in image_names
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    # define transforms for image
    val_transforms = Compose(
        [ScaleIntensity(),
         AddChannel(),
         Resize((96, 96, 96)),
         ToTensor()])
    # define nifti dataset
    val_ds = NiftiDataset(image_files=images,
                          labels=labels,
                          transform=val_transforms,
                          image_only=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # create DenseNet121 and place it on the evaluation device explicitly:
    # current ignite documents that create_supervised_evaluator only moves
    # the batches, so relying on it to move the model can leave model and
    # data on different devices.
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3,
                                                   in_channels=1,
                                                   out_channels=2).to(device)

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}

    def prepare_batch(batch, device=None, non_blocking=False):
        """Keep only (image, label) and drop the trailing meta data item."""
        return _prepare_batch((batch[0], batch[1]), device, non_blocking)

    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        # no need to print loss value, so disable per iteration output
        output_transform=lambda x: None,
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    prediction_saver = ClassificationSaver(
        output_dir="tempdir",
        batch_transform=lambda batch: batch[2],
        output_transform=lambda output: output[0].argmax(1),
    )
    prediction_saver.attach(evaluator)

    # the model was trained by "densenet_training_array" example
    CheckpointLoader(load_path="./runs_array/net_checkpoint_20.pth",
                     load_dict={
                         "net": net
                     }).attach(evaluator)

    # create a validation data loader
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    state = evaluator.run(val_loader)
    print(state)
# The evaluator has no loss worth printing, so only the metrics are reported;
# the print functions can still be customised by the user.
def _no_iteration_output(_engine_output):
    """Return None so that nothing is printed per iteration."""
    return None


def _meta_data_of(batch):
    """Pick the 3rd batch item, assumed to be the meta data for array data."""
    return batch[2]


def _segmentation_of(output):
    """Turn the first element of the engine output into a segmentation."""
    return predict_segmentation(output[0])


val_stats_handler = StatsHandler(name="evaluator",
                                 output_transform=_no_iteration_output)
val_stats_handler.attach(evaluator)

# write every predicted segmentation as "<name>_seg.nii.gz" below "tempdir"
file_saver = SegmentationSaver(
    output_dir="tempdir",
    output_ext=".nii.gz",
    output_postfix="seg",
    name="evaluator",
    batch_transform=_meta_data_of,
    output_transform=_segmentation_of,
)
file_saver.attach(evaluator)

# weights come from the "unet_training_array" example
ckpt_saver = CheckpointLoader(
    load_path="./runs/net_checkpoint_50.pth",
    load_dict={"net": net},
)
ckpt_saver.attach(evaluator)

# one image per iteration goes through sliding window inference
loader = DataLoader(
    ds,
    batch_size=1,
    num_workers=1,
    pin_memory=torch.cuda.is_available(),
)
state = evaluator.run(loader)
shutil.rmtree(tempdir)
Exemple #4
0
def main():
    """Evaluate a 3D UNet with MONAI's ``SupervisedEvaluator`` workflow.

    Five synthetic image/mask pairs are generated into a temporary
    directory and loaded through dictionary transforms.  The evaluator runs
    sliding-window inference, post-processes the prediction (sigmoid,
    thresholding, largest connected component), reports mean Dice and
    accuracy, and saves the segmentations to ``./runs/``.  Weights are loaded
    from ``./runs/net_key_metric=0.9101.pth``.

    The temporary input directory is always removed, even when data
    generation or evaluation raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 5 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),
        ])

        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = monai.data.DataLoader(val_ds,
                                           batch_size=1,
                                           num_workers=4)

        # create the UNet; fall back to the CPU when CUDA is unavailable
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        val_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointLoader(load_path="./runs/net_key_metric=0.9101.pth",
                             load_dict={"net": net}),
            SegmentationSaver(
                output_dir="./runs/",
                batch_transform=lambda batch: batch["image_meta_dict"],
                output_transform=lambda output: output["pred"],
            ),
        ]

        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loader,
            network=net,
            inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                         sw_batch_size=4,
                                         overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice":
                MeanDice(include_background=True,
                         output_transform=lambda x: (x["pred"], x["label"]))
            },
            additional_metrics={
                "val_acc":
                Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
            },
            val_handlers=val_handlers,
        )
        evaluator.run()
    finally:
        # mkdtemp() leaves deletion to the caller; do it on every exit path
        shutil.rmtree(tempdir)
Exemple #5
0
def main():
    """Train a basic 2D MONAI UNet for fetal brain segmentation with ignite.

    ignite drives the training loop, the periodic validation, checkpointing,
    the optional exponential learning-rate decay and the TensorBoard logging.
    Every setting is read from the YAML file passed via ``--config`` (sections
    ``device``, ``training``, ``data`` and ``output``).

    Raises:
        BlockingIOError: if ``training.model_to_load`` is configured but the
            file does not exist (type kept for backward compatibility).
    """
    # ---- Read input and configuration parameters ----
    parser = argparse.ArgumentParser(
        description='Run basic UNet with MONAI - Ignite version.')
    parser.add_argument('--config',
                        dest='config',
                        metavar='config',
                        type=str,
                        help='config file')
    args = parser.parse_args()

    # NOTE(security): yaml.load can construct more than plain data; only pass
    # trusted config files here (yaml.safe_load would be the stricter choice).
    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    device_cfg = config_info['device']
    training_cfg = config_info['training']
    data_cfg = config_info['data']
    output_cfg = config_info['output']

    # GPU params
    cuda_device = device_cfg['cuda_device']
    num_workers = device_cfg['num_workers']
    # training and validation params
    loss_type = training_cfg['loss_type']
    batch_size_train = training_cfg['batch_size_train']
    batch_size_valid = training_cfg['batch_size_valid']
    lr = float(training_cfg['lr'])
    lr_decay = training_cfg['lr_decay']
    if lr_decay is not None:
        lr_decay = float(lr_decay)
    nr_train_epochs = training_cfg['nr_train_epochs']
    validation_every_n_epochs = training_cfg['validation_every_n_epochs']
    sliding_window_validation = training_cfg['sliding_window_validation']
    # optional keys: a missing (or null) entry simply disables the feature
    model_to_load = training_cfg.get('model_to_load')
    if model_to_load is not None and not os.path.exists(model_to_load):
        # NOTE(review): FileNotFoundError would describe this better; the
        # original exception type is kept so existing callers still work.
        raise BlockingIOError(
            "cannot find model: {}".format(model_to_load))
    seed = training_cfg.get('manual_seed')
    # data params
    data_root = data_cfg['data_root']
    training_list = data_cfg['training_list']
    validation_list = data_cfg['validation_list']
    # model saving
    out_model_dir = os.path.join(
        output_cfg['out_model_dir'],
        datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
        output_cfg['output_subfix'])
    print("Saving to directory ", out_model_dir)
    if 'cache_dir' in output_cfg:
        out_cache_dir = output_cfg['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    max_nr_models_saved = output_cfg['max_nr_models_saved']
    val_image_to_tensorboad = output_cfg['val_image_to_tensorboad']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)
    if seed is not None:
        # set manual seed if required (both numpy and torch)
        set_determinism(seed=seed)

    # ---- Data preparation ----
    # create cache directory to store results for Persistent Dataset
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')

    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'],
                spatial_size=[96, 96],
                interp_order=[1, 0],
                anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'],
                         roi_size=[96, 96, 1],
                         random_size=False),
        RandRotated(keys=['img', 'seg'],
                    degrees=90,
                    prob=0.2,
                    spatial_axes=[0, 1],
                    interp_order=[1, 0],
                    reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    # create a training data loader (monai.data.Dataset or CacheDataset are
    # drop-in alternatives to the persistent, on-disk cache used here)
    train_ds = monai.data.PersistentDataset(data=train_files,
                                            transform=train_transforms,
                                            cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'],
                    spatial_size=[96, 96],
                    interp_order=[1, 0],
                    anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'],
                    spatial_size=[96, 96],
                    interp_order=[1, 0],
                    anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'],
                             roi_size=[96, 96, 1],
                             random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    # create a validation data loader
    val_ds = monai.data.PersistentDataset(data=val_files,
                                          transform=val_transforms,
                                          cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)

    # ---- Network preparation ----
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    # Use a real torch.device rather than the bare integer returned by
    # current_device(): index 0 is falsy, which makes truthiness checks on
    # the device skip it.  The network is moved here once, for both evaluator
    # variants, instead of depending on ignite to do it.
    device = torch.device("cuda", torch.cuda.current_device())
    net.to(device)

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    if lr_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt,
                                                              gamma=lr_decay,
                                                              last_epoch=-1)

    # ---- Set ignite trainer ----
    # function to manage batch at training
    def prepare_batch(batch, device=None, non_blocking=False):
        """Extract the (image, segmentation) pair from a dictionary batch."""
        return _prepare_batch((batch['img'], batch['seg']), device,
                              non_blocking)

    trainer = create_supervised_trainer(model=net,
                                        optimizer=opt,
                                        loss_fn=loss_function,
                                        device=device,
                                        non_blocking=False,
                                        prepare_batch=prepare_batch)

    # optionally resume network and optimizer state from a previous run
    if model_to_load is not None:
        checkpoint_loader = CheckpointLoader(load_path=model_to_load,
                                             load_dict={
                                                 'net': net,
                                                 'opt': opt,
                                             })
        checkpoint_loader.attach(trainer)

    # adding checkpoint handler to save models (network params and optimizer
    # stats) during training; attached for resumed runs as well, otherwise a
    # resumed training would never write any checkpoint
    checkpoint_handler = ModelCheckpoint(out_model_dir,
                                         'net',
                                         n_saved=max_nr_models_saved,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'net': net,
                                  'opt': opt
                              })

    # StatsHandler prints loss at every iteration and print metrics at every epoch
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_train)
    train_tensorboard_stats_handler.attach(trainer)

    if lr_decay is not None:
        print("Using Exponential LR decay")
        lr_schedule_handler = LrScheduleHandler(lr_scheduler,
                                                print_lr=True,
                                                name="lr_scheduler",
                                                writer=writer_train)
        lr_schedule_handler.attach(trainer)

    # ---- Set ignite evaluator to perform validation at training ----
    # set parameters for validation
    metric_name = 'Mean_Dice'
    # add evaluation metric to the evaluator engine
    val_metrics = {
        "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False),
        "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False)
    }

    def _sliding_window_processor(engine, batch):
        """Run no-grad sliding-window inference on one validation batch."""
        net.eval()
        with torch.no_grad():
            val_images = batch['img'].to(device)
            val_labels = batch['seg'].to(device)
            roi_size = (96, 96, 1)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 batch_size_valid, net)
            return seg_probs, val_labels

    if sliding_window_validation:
        # use sliding window inference at validation
        print("3D evaluator is used")
        evaluator = Engine(_sliding_window_processor)
        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)
    else:
        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        print("2D evaluator is used")
        evaluator = create_supervised_evaluator(model=net,
                                                metrics=val_metrics,
                                                device=device,
                                                non_blocking=True,
                                                prepare_batch=prepare_batch)

    # The loader length counts the trailing partial batch too, so this is the
    # true number of iterations per epoch (integer division of the dataset
    # size undercounts it and can even yield 0 for small datasets).
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        # no need to print loss value, so disable per iteration output
        output_transform=lambda x: None,
        # fetch global epoch number from trainer
        global_epoch_transform=lambda x: trainer.state.epoch)
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_valid,
        # no need to plot loss value, so disable per iteration output
        output_transform=lambda x: None,
        # fetch global iteration number from trainer
        global_epoch_transform=lambda x: trainer.state.iteration)
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # (attached to every validation iteration)
    if val_image_to_tensorboad:
        val_tensorboard_image_handler = TensorBoardImageHandler(
            summary_writer=writer_valid,
            batch_transform=lambda batch: (batch['img'], batch['seg']),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch)
        evaluator.add_event_handler(
            event_name=Events.ITERATION_COMPLETED(every=1),
            handler=val_tensorboard_image_handler)

    # ---- Run training ----
    try:
        trainer.run(train_loader, nr_train_epochs)
    finally:
        # flush pending TensorBoard events even if training is interrupted
        writer_train.close()
        writer_valid.close()
    print("Done!")
Exemple #6
0
def run_inference_test(root_dir, model_file, device=torch.device("cuda:0")):
    """Run sliding-window inference on the NIfTI pairs found in ``root_dir``.

    Args:
        root_dir: directory searched for ``im*.nii.gz`` / ``seg*.nii.gz``
            files; the predicted segmentations are saved there as well.
        model_file: checkpoint whose ``"net"`` entry is loaded into the UNet.
        device: device the network and the evaluator run on.

    Returns:
        ``evaluator.state.best_metric`` after the evaluation finished.
    """
    keys = ["image", "label"]
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [
        dict(zip(keys, pair)) for pair in zip(image_paths, label_paths)
    ]

    # loading and intensity pre-processing shared by image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # validation data loader, one volume per batch
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # the network to be evaluated
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net = net.to(device)

    # each step stores its result under a new key:
    # "pred" -> "pred_act" -> "pred_act_dis" (the last step has no postfix)
    val_post_transforms = Compose([
        Activationsd(keys="pred", output_postfix="act", sigmoid=True),
        AsDiscreted(keys="pred_act",
                    output_postfix="dis",
                    threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred_act_dis",
                                       applied_values=[1],
                                       output_postfix=None),
    ])

    def _no_iteration_output(_output):
        return None

    def _image_meta(batch):
        # meta data entries of the image that are handed to the saver
        return {
            "filename_or_obj": batch["image.filename_or_obj"],
            "affine": batch["image.affine"],
            "original_affine": batch["image.original_affine"],
            "spatial_shape": batch["image.spatial_shape"],
        }

    def _final_prediction(output):
        return output["pred_act_dis"]

    def _prediction_and_label(output):
        return (output["pred_act_dis"], output["label"])

    stats_handler = StatsHandler(output_transform=_no_iteration_output)
    checkpoint_loader = CheckpointLoader(load_path=f"{model_file}",
                                         load_dict={"net": net})
    segmentation_saver = SegmentationSaver(
        output_dir=root_dir,
        batch_transform=_image_meta,
        output_transform=_final_prediction,
    )
    val_handlers = [stats_handler, checkpoint_loader, segmentation_saver]

    inferer = SlidingWindowInferer(roi_size=(96, 96, 96),
                                   sw_batch_size=4,
                                   overlap=0.5)
    mean_dice = MeanDice(include_background=True,
                         output_transform=_prediction_and_label)
    accuracy = Accuracy(output_transform=_prediction_and_label)

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=inferer,
        post_transform=val_post_transforms,
        key_val_metric={"val_mean_dice": mean_dice},
        additional_metrics={"val_acc": accuracy},
        val_handlers=val_handlers,
    )
    evaluator.run()

    return evaluator.state.best_metric
def main():
    """Evaluate a pre-trained 3D UNet on synthetic data using dictionary-based transforms.

    The function:
      1. writes five synthetic image/segmentation NIfTI pairs into a fresh temporary directory,
      2. wraps them in a ``monai.data.Dataset`` keyed by ``"img"`` and ``"seg"``,
      3. runs sliding-window inference inside an ignite ``Engine`` with a ``MeanDice`` metric,
         a ``StatsHandler`` and a ``SegmentationSaver`` attached,
      4. prints the final engine state.

    Network weights are loaded from ``./runs/net_checkpoint_50.pth``. The temporary directory is
    removed when the function exits, whether evaluation succeeded or raised.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    # mkdtemp() leaves cleanup to the caller, so guard everything that can fail with try/finally
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        # sorted() keeps image i and segmentation i at the same position in both lists
        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose(
            [
                LoadNiftid(keys=["img", "seg"]),
                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
                ScaleIntensityd(keys=["img", "seg"]),
                ToTensord(keys=["img", "seg"]),
            ]
        )
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

        # use the first GPU when CUDA is present, otherwise fall back to the CPU instead of
        # failing when the network is moved to the device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        net = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        net.to(device)

        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        def _sliding_window_processor(engine, batch):
            """Run sliding-window inference on one batch and return ``(predictions, labels)``."""
            net.eval()
            with torch.no_grad():
                val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
                seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                return seg_probs, val_labels

        evaluator = Engine(_sliding_window_processor)

        # add evaluation metric to the evaluator engine
        MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

        # StatsHandler prints loss at every iteration and print metrics at every epoch,
        # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
        val_stats_handler = StatsHandler(
            name="evaluator",
            output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        )
        val_stats_handler.attach(evaluator)

        # convert the necessary metadata from batch data
        # NOTE: output_dir is the literal relative path "tempdir", not the temporary directory
        # created above, so the saved predictions are not touched by the cleanup below
        SegmentationSaver(
            output_dir="tempdir",
            output_ext=".nii.gz",
            output_postfix="seg",
            name="evaluator",
            batch_transform=lambda batch: {
                "filename_or_obj": batch["img.filename_or_obj"],
                "affine": batch["img.affine"],
            },
            # output[0] is the first element returned by _sliding_window_processor (the predictions)
            output_transform=lambda output: predict_segmentation(output[0]),
        ).attach(evaluator)
        # the model was trained by "unet_training_dict" example
        CheckpointLoader(load_path="./runs/net_checkpoint_50.pth", load_dict={"net": net}).attach(evaluator)

        # sliding window inference for one image at every iteration
        val_loader = DataLoader(
            val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
        )
        state = evaluator.run(val_loader)
        print(state)
    finally:
        # always remove the synthetic input data, even when evaluation failed part-way
        shutil.rmtree(tempdir)
Exemple #8
0
def evaluate(args):
    """Evaluate a checkpointed 3D UNet in a distributed setting, one process per GPU.

    Args:
        args: object exposing ``local_rank`` (index of the GPU driven by this process) and
            ``dir`` (folder with the evaluation volumes). When the folder does not exist,
            the process with local rank 0 creates it and fills it with 16 synthetic
            image/label pairs.

    The process joins an NCCL process group (``init_method="env://"``), runs sliding-window
    inference through ``SupervisedEvaluator`` and destroys the process group afterwards.
    """
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # local rank 0 creates 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # fixed seed so that every node generates the same random data
        np.random.seed(seed=0)
        for i in range(16):
            image_vol, label_vol = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            for file_name, volume in ((f"img{i:d}.nii.gz", image_vol), (f"seg{i:d}.nii.gz", label_vol)):
                nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(args.dir, file_name))

    # every GPU is driven by its own process; join the distributed group before touching the data
    dist.init_process_group(backend="nccl", init_method="env://")

    image_paths = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [
        {"image": image_path, "label": label_path} for image_path, label_path in zip(image_paths, label_paths)
    ]

    # pre-processing applied to each image/label dictionary
    loading_steps = [
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        ToTensord(keys=["image", "label"]),
    ]
    val_transforms = Compose(loading_steps)

    # dataset plus a distributed sampler that hands this process its share of the samples
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_sampler = DistributedSampler(val_ds, shuffle=False)
    # sliding window inference takes a single image per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
    )

    # build the UNet on this process's GPU
    device = torch.device(f"cuda:{args.local_rank}")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[args.local_rank])

    # post-processing of the raw network output stored under "pred"
    post_steps = [
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ]
    val_post_transforms = Compose(post_steps)

    checkpoint_loader = CheckpointLoader(
        load_path="./runs/checkpoint_epoch=4.pth",
        load_dict={"net": net},
        # remap tensors stored for "cuda:0" onto the GPU owned by this process
        map_location={"cuda:0": f"cuda:{args.local_rank}"},
    )
    val_handlers = [checkpoint_loader]
    if dist.get_rank() == 0:
        # only global rank 0 logs statistics and writes segmentations
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        val_handlers += [
            StatsHandler(output_transform=lambda x: None),
            SegmentationSaver(
                output_dir="./runs/",
                batch_transform=lambda batch: batch["image_meta_dict"],
                output_transform=lambda output: output["pred"],
            ),
        ]

    def _pred_and_label(x):
        """Pick the (prediction, label) pair that both metrics consume."""
        return x["pred"], x["label"]

    inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5)
    mean_dice = MeanDice(include_background=True, output_transform=_pred_and_label, device=device)
    accuracy = Accuracy(output_transform=_pred_and_label, device=device)
    # AMP evaluation is requested only when the installed PyTorch is at least version 1.6
    use_amp = monai.config.get_torch_version_tuple() >= (1, 6)

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=inferer,
        post_transform=val_post_transforms,
        key_val_metric={"val_mean_dice": mean_dice},
        additional_metrics={"val_acc": accuracy},
        val_handlers=val_handlers,
        amp=use_amp,
    )
    evaluator.run()
    dist.destroy_process_group()
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4):
    """Run sliding-window inference over the volumes in ``root_dir`` and return the best metric.

    Args:
        root_dir: folder searched for ``im*.nii.gz`` / ``seg*.nii.gz`` files; it is also the
            output directory for the saved predictions.
        model_file: checkpoint path handed to ``CheckpointLoader``.
        device: device the network is moved to and the evaluator runs on.
        amp: truthy value to request mixed-precision evaluation.
        num_workers: worker count for the validation data loader.

    Returns:
        ``evaluator.state.best_metric`` after the evaluator has run.
    """
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [
        {"image": image_path, "label": label_path} for image_path, label_path in zip(image_paths, label_paths)
    ]

    # pre-processing applied to each image/label dictionary
    loading_steps = [
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ]
    val_transforms = Compose(loading_steps)

    # validation dataset and loader, one volume per batch
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # network under test, placed on the requested device
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net = net.to(device)

    # key under which the image metadata is stored; used by both saving paths below
    image_meta_key = PostFix.meta("image")

    post_steps = [
        ToTensord(keys=["pred", "label"]),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
        SaveImaged(keys="pred", meta_keys=image_meta_key, output_dir=root_dir, output_postfix="seg_transform"),
    ]
    val_postprocessing = Compose(post_steps)

    stats_handler = StatsHandler(iteration_log=False)
    checkpoint_loader = CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net})
    segmentation_saver = SegmentationSaver(
        output_dir=root_dir,
        output_postfix="seg_handler",
        batch_transform=from_engine(image_meta_key),
        output_transform=from_engine("pred"),
    )
    val_handlers = [stats_handler, checkpoint_loader, segmentation_saver]

    inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5)
    mean_dice = MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
    accuracy = Accuracy(output_transform=from_engine(["pred", "label"]))

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=inferer,
        postprocessing=val_postprocessing,
        key_val_metric={"val_mean_dice": mean_dice},
        additional_metrics={"val_acc": accuracy},
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    best_metric = evaluator.state.best_metric
    return best_metric

# Classification evaluation script fragment: builds an ignite evaluator around `net`,
# attaches logging / prediction-saving / checkpoint-loading handlers, then runs it over
# the validation data. `net`, `device`, `prepare_batch`, `val_files` and `val_transforms`
# are not defined in this fragment and must already exist when it runs.

# name under which the accuracy metric is registered on the evaluator
metric_name = 'Accuracy'
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: Accuracy()}
# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
# NOTE(review): the positional `True` is presumably ignite's `non_blocking` flag - confirm
# against the installed ignite version
evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

# add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
)
val_stats_handler.attach(evaluator)

# the batch is indexed as a dictionary here: the source file name is read from its
# 'img.filename_or_obj' entry, and the value saved for each sample is the argmax over
# dimension 1 of the first element of the engine output
prediction_saver = ClassificationSaver(output_dir='tempdir', name='evaluator',
                                       batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj']},
                                       output_transform=lambda output: output[0].argmax(1))
prediction_saver.attach(evaluator)

# the model was trained by "densenet_training_dict" example
CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)

# create a validation data loader (two samples per batch)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

# run the evaluation; `state` holds the object returned by the engine run
state = evaluator.run(val_loader)
def main(tempdir):
    """Evaluate a pre-trained 3D UNet on synthetic data using array-based transforms.

    Args:
        tempdir: directory into which the five synthetic image/segmentation NIfTI pairs
            are written and from which they are read back. It is not created here.

    Inference runs through an ignite ``Engine`` with a ``MeanDice`` metric; each
    post-processed prediction is written by ``SaveImage`` and the final engine state is printed.
    Network weights are loaded from ``./runs_array/net_checkpoint_100.pt``.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        image_vol, label_vol = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        nib.save(nib.Nifti1Image(image_vol, np.eye(4)), os.path.join(tempdir, f"im{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(label_vol, np.eye(4)), os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    # sorted() keeps image i and segmentation i at the same position in both lists
    image_paths = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # separate transform chains for the image and for the segmentation
    image_transform = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
    label_transform = Compose([AddChannel(), EnsureType()])
    ds = ImageDataset(
        image_paths,
        label_paths,
        transform=image_transform,
        seg_transform=label_transform,
        image_only=False,
    )

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net = net.to(device)

    # window size and number of windows processed together during sliding-window inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # NOTE: output_dir is the literal relative path "tempdir", not the `tempdir` argument
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(engine, batch):
        """Infer one batch, save every post-processed prediction, return (predictions, labels)."""
        net.eval()
        with torch.no_grad():
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            raw_output = sliding_window_inference(inputs, roi_size, sw_batch_size, net)
            predictions = [post_trans(item) for item in decollate_batch(raw_output)]
            # batch[2] is split per sample and passed to the saver next to each prediction
            per_item_data = decollate_batch(batch[2])
            for prediction, item_data in zip(predictions, per_item_data):
                save_image(prediction, item_data)
        return predictions, targets

    evaluator = Engine(_sliding_window_processor)

    # register the Dice metric on the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    def _suppress_iteration_output(x):
        """Return None so that StatsHandler has no per-iteration loss value to print."""
        return None

    # StatsHandler is used for the metrics only; the per-iteration output is disabled above
    StatsHandler(name="evaluator", output_transform=_suppress_iteration_output).attach(evaluator)

    # the model was trained by "unet_training_array" example
    checkpoint_loader = CheckpointLoader(
        load_path="./runs_array/net_checkpoint_100.pt",
        load_dict={"net": net},
    )
    checkpoint_loader.attach(evaluator)

    # one image per iteration for sliding window inference
    loader = DataLoader(
        ds,
        batch_size=1,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
    )
    state = evaluator.run(loader)
    print(state)