Ejemplo n.º 1
0
    def test_misconfig(self, input_params):
        """Expect a ``ValueError`` whose message matches "compatib".

        Both the metric construction and the update with a (2, 2) prediction
        against a (2, 1) label run inside the assertion context, so the
        error may be raised by either step.
        """
        with self.assertRaisesRegex(ValueError, "compatib"):
            metric = MeanDice(**input_params)
            batch = [torch.Tensor([[0, 1], [1, 0]]), torch.ones((2, 1))]
            metric.update(batch)
Ejemplo n.º 2
0
    def test_compute(self, input_params, expected_avg, details_shape):
        """Drive the metric through an engine by firing its events by hand.

        Two iterations are simulated by assigning ``engine.state.output`` and
        firing ``ITERATION_COMPLETED``; ``EPOCH_COMPLETED`` is then fired and
        the aggregated value and the shape of the per-item details are
        compared with ``expected_avg`` and ``details_shape``.
        """
        metric = MeanDice(**input_params)

        def _noop_step(engine, batch):
            # The engine is never run; all events below are fired manually.
            return None

        engine = Engine(_noop_step)
        metric.attach(engine=engine, name="mean_dice")

        # Predictions are passed as a list of channel-first tensors and are
        # the same in both iterations; only the labels change.
        label_batches = (
            torch.Tensor([[[0], [1]], [[0], [1]]]),
            torch.Tensor([[[0], [1]], [[1], [0]]]),
        )
        for labels in label_batches:
            preds = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
            engine.state.output = {"pred": preds, "label": labels}
            engine.fire_event(Events.ITERATION_COMPLETED)

        engine.fire_event(Events.EPOCH_COMPLETED)

        torch.testing.assert_allclose(engine.state.metrics["mean_dice"],
                                      expected_avg)
        details = engine.state.metric_details["mean_dice"]
        self.assertTupleEqual(tuple(details.shape), details_shape)
Ejemplo n.º 3
0
 def val_key_metric(self, context: Context):
     """Build the validation Dice metrics.

     Returns a dict holding ``"val_mean_dice"`` (computed from the
     ``"pred"``/``"label"`` outputs) plus one ``"<label>_dice"`` entry per
     name in ``self._labels`` other than ``"background"``, each reading
     ``"pred_<label>"``/``"label_<label>"``. ``context`` is not used.
     """
     overall = MeanDice(
         output_transform=from_engine(["pred", "label"]),
         include_background=False,
     )
     metrics = {"val_mean_dice": overall}

     foreground = (name for name in self._labels if name != "background")
     for name in foreground:
         metrics[name + "_dice"] = MeanDice(
             output_transform=from_engine(["pred_" + name, "label_" + name]),
             include_background=False,
         )
     return metrics
Ejemplo n.º 4
0
def region_wise_metrics(regions, metric, prefix, keys=("pred", "label")):
    """Build an overall Dice metric plus one Dice metric per region.

    Args:
        regions: either a dict mapping region name to index, or an iterable
            of region names (indices are then assigned from 1 in iteration
            order). A falsy value yields only the overall metric.
        metric: dictionary key under which the overall metric is stored.
        prefix: prefix of the per-region keys, ``"<prefix>_<name>_dice"``.
        keys: engine output keys handed to ``from_engine`` /
            ``from_engine_idx``.

    Returns:
        dict mapping metric names to ``MeanDice`` instances.
    """
    result = {
        metric: MeanDice(output_transform=from_engine(keys),
                         include_background=False)
    }
    if not regions:
        return result

    if isinstance(regions, dict):
        name_to_idx = regions
    else:
        name_to_idx = {name: idx for idx, name in enumerate(regions, start=1)}

    for name, idx in name_to_idx.items():
        result[f"{prefix}_{name}_dice"] = MeanDice(
            output_transform=from_engine_idx(keys, idx),
            include_background=False,
        )
    return result
Ejemplo n.º 5
0
    def test_compute(self, input_params, expected_avg):
        """Update the metric twice directly and compare the computed average.

        The prediction tensor is the same in both updates; the label matches
        it only in the second one.
        """
        metric = MeanDice(**input_params)

        label_batches = (
            [[[0], [1]], [[0], [1]]],
            [[[0], [1]], [[1], [0]]],
        )
        for labels in label_batches:
            preds = torch.Tensor([[[0], [1]], [[1], [0]]])
            metric.update([preds, torch.Tensor(labels)])

        self.assertAlmostEqual(metric.compute(), expected_avg, places=4)
Ejemplo n.º 6
0
    def test_shape_mismatch(self, input_params, _expected_avg, _details_shape):
        """Updating with a label whose shape differs from the prediction fails.

        A (2, 2) prediction is paired first with a (2, 3) label and then with
        a (3, 2) label; each update must raise ``AssertionError`` or
        ``ValueError``. The two underscore-prefixed parameters are unused.
        """
        metric = MeanDice(**input_params)
        for label_shape in ((2, 3), (3, 2)):
            with self.assertRaises((AssertionError, ValueError)):
                preds = torch.Tensor([[0, 1], [1, 0]])
                labels = torch.ones(label_shape)
                metric.update([preds, labels])
Ejemplo n.º 7
0
    def test_compute(self, input_params, expected_avg, details_shape):
        """Update an attached metric directly and check value and details.

        The metric is attached to an engine that is never run. It is updated
        once with batched tensors and once with lists of per-item tensors,
        then ``compute()`` is compared with ``expected_avg`` and the shape of
        ``engine.state.metric_details["mean_dice"]`` with ``details_shape``.
        """
        metric = MeanDice(**input_params)

        def _noop_step(engine, batch):
            # Never invoked: the engine only serves as the attach target.
            return None

        engine = Engine(_noop_step)
        metric.attach(engine=engine, name="mean_dice")

        # First update: whole-batch tensors.
        metric.update([
            torch.Tensor([[[0], [1]], [[1], [0]]]),
            torch.Tensor([[[0], [1]], [[0], [1]]]),
        ])

        # Second update: lists of per-item tensors, prediction == label.
        items = [[[0], [1]], [[1], [0]]]
        preds = [torch.Tensor(item) for item in items]
        labels = [torch.Tensor(item) for item in items]
        metric.update([preds, labels])

        self.assertAlmostEqual(metric.compute(), expected_avg, places=4)
        # NOTE(review): the details are presumably written to the engine
        # state by compute() via the attached engine - confirm.
        details = engine.state.metric_details["mean_dice"]
        self.assertTupleEqual(tuple(details.shape), details_shape)
Ejemplo n.º 8
0
    def ensemble_evaluate(post_transforms, models):
        """Run an ``EnsembleEvaluator`` over ``models`` and report mean Dice.

        ``device``, ``test_loader`` and ``opt`` are taken from the enclosing
        scope. Inference uses a sliding window of ``opt.patch_size`` with 50%
        overlap, and the key metric ``"test_mean_dice"`` includes the
        background channel.

        Args:
            post_transforms: passed to the evaluator as ``post_transform``.
            models: networks passed to the evaluator as ``networks``.
        """

        def _pred_and_label(output):
            return (output["pred"], output["label"])

        window_inferer = SlidingWindowInferer(roi_size=opt.patch_size,
                                              sw_batch_size=opt.batch_size,
                                              overlap=0.5)
        dice = MeanDice(
            include_background=True,
            output_transform=_pred_and_label,
        )

        evaluator = EnsembleEvaluator(
            device=device,
            val_data_loader=test_loader,
            pred_keys=opt.pred_keys,
            networks=models,
            inferer=window_inferer,
            post_transform=post_transforms,
            key_val_metric={"test_mean_dice": dice},
        )
        evaluator.run()
Ejemplo n.º 9
0
def train(args):
    """Train a DynUNet for one task and fold, validating periodically.

    The function builds the data loaders, the network (optionally wrapped in
    ``DistributedDataParallel``), an SGD optimizer with a polynomial learning
    rate schedule, a ``DynUNetEvaluator`` and a ``DynUNetTrainer``, then runs
    the trainer. Checkpoints and the log file go to
    ``./runs_<task_id>_fold<fold>_<expr_name>/``.

    Args:
        args: namespace of options. Attributes read here: ``task_id``,
            ``fold``, ``expr_name``, ``interval``, ``learning_rate``,
            ``max_epochs``, ``multi_gpu``, ``amp``, ``lr_decay``,
            ``sw_batch_size``, ``tta_val``, ``batch_dice``, ``window_mode``,
            ``eval_overlap``, ``local_rank``, ``determinism_flag``,
            ``determinism_seed`` and ``checkpoint``. ``args`` is also
            forwarded to ``get_data``.

    The module-level tables ``data_loader_params`` and ``patch_size`` are
    indexed by ``task_id``.
    """
    # load hyper parameters
    task_id = args.task_id
    fold = args.fold
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, fold,
                                                   args.expr_name)
    log_filename = "nnunet_task{}_fold{}.log".format(task_id, fold)
    log_filename = os.path.join(val_output_dir, log_filename)
    interval = args.interval
    learning_rate = args.learning_rate
    max_epochs = args.max_epochs
    multi_gpu_flag = args.multi_gpu
    amp_flag = args.amp
    lr_decay_flag = args.lr_decay
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    batch_dice = args.batch_dice
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    local_rank = args.local_rank
    determinism_flag = args.determinism_flag
    determinism_seed = args.determinism_seed
    if determinism_flag:
        set_determinism(seed=determinism_seed)
        if local_rank == 0:
            print("Using deterministic training.")

    # devices and data loaders
    train_batch_size = data_loader_params[task_id]["batch_size"]
    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")

        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, val_loader = get_data(args, mode="validation")
    _, train_loader = get_data(args, batch_size=train_batch_size, mode="train")

    # produce the network
    checkpoint = args.checkpoint
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net,
                                      device_ids=[device],
                                      find_unused_parameters=True)

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True,
    )

    # polynomial decay: lr * (1 - epoch / max_epochs) ** 0.9
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs)**0.9)

    # produce evaluator
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=val_output_dir,
                        save_dict={"net": net},
                        save_key_metric=True),
    ]

    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        n_classes=len(properties["labels"]),
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        post_transform=None,
        key_val_metric={
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=amp_flag,
        tta_val=tta_val,
    )

    # produce trainer
    loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
    train_handlers = []
    if lr_decay_flag:
        train_handlers += [
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)
        ]

    train_handlers += [
        ValidationHandler(validator=evaluator,
                          interval=interval,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]

    trainer = DynUNetTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp_flag,
    )

    # Logging is configured on rank 0 only. The handlers are created inside
    # this branch so that the other ranks do not open the log file just to
    # leave the handle unused and unclosed.
    if local_rank == 0:
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

        # file handler
        fhandler = logging.FileHandler(log_filename)
        fhandler.setLevel(logging.INFO)
        fhandler.setFormatter(formatter)

        # stream handler (console / notebook cells)
        chandler = logging.StreamHandler()
        chandler.setLevel(logging.INFO)
        chandler.setFormatter(formatter)

        logger = logging.getLogger()
        logger.addHandler(fhandler)
        logger.addHandler(chandler)
        logger.setLevel(logging.INFO)

    # run
    trainer.run()
Ejemplo n.º 10
0
def validation(args):
    """Evaluate a DynUNet checkpoint on the validation data and print Dice.

    Builds the network from ``args.checkpoint``, runs a ``DynUNetEvaluator``
    with sliding-window inference and, on rank 0, prints the evaluator
    metrics. When there are more than two classes it also prints the mean of
    each column of the ``"val_mean_dice"`` metric details, one line per
    non-background label.

    Args:
        args: namespace of options. Attributes read here: ``task_id``,
            ``fold``, ``expr_name``, ``sw_batch_size``, ``tta_val``,
            ``window_mode``, ``eval_overlap``, ``multi_gpu``,
            ``local_rank``, ``amp`` and ``checkpoint``. ``args`` is also
            forwarded to ``get_data``.
    """
    # Read every option up front.
    task_id = args.task_id
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    multi_gpu_flag = args.multi_gpu
    local_rank = args.local_rank
    amp = args.amp
    checkpoint = args.checkpoint
    val_output_dir = f"./runs_{task_id}_fold{args.fold}_{args.expr_name}/"

    # Pick the device (one GPU per process in the distributed case).
    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    # Data and network.
    properties, val_loader = get_data(args, mode="validation")
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)
    if multi_gpu_flag:
        net = DistributedDataParallel(module=net,
                                      device_ids=[device],
                                      find_unused_parameters=True)

    n_classes = len(properties["labels"])
    net.eval()

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    window_inferer = SlidingWindowInferer(
        roi_size=patch_size[task_id],
        sw_batch_size=sw_batch_size,
        overlap=eval_overlap,
        mode=window_mode,
    )
    dice = MeanDice(
        include_background=False,
        output_transform=_pred_and_label,
    )
    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        n_classes=n_classes,
        inferer=window_inferer,
        post_transform=None,
        key_val_metric={"val_mean_dice": dice},
        additional_metrics=None,
        amp=amp,
        tta_val=tta_val,
    )
    evaluator.run()

    # Only rank 0 reports.
    if local_rank != 0:
        return

    print(evaluator.state.metrics)
    details = evaluator.state.metric_details["val_mean_dice"]
    if n_classes > 2:
        # Column i of the details belongs to label i + 1.
        for label in range(1, n_classes):
            column_mean = details[:, label - 1].mean()
            print(f"mean dice for label {label} is {column_mean}")
Ejemplo n.º 11
0
    def train(index):
        """Build a DynUNet from ``opt``, train it on one fold and return it.

        ``opt``, ``device``, ``train_loaders`` and ``val_loaders`` come from
        the enclosing scope. The network is trained on
        ``train_loaders[index]`` for ``opt.epochs`` epochs and validated on
        ``val_loaders[index]`` every 5 epochs; checkpoints are written to
        ``./runs/``.

        Args:
            index: position of the train/validation loader pair to use.

        Returns:
            The trained network.
        """

        # ---------- Build the nn-Unet network ------------

        if opt.resolution is None:
            sizes, spacings = opt.patch_size, opt.spacing
        else:
            sizes, spacings = opt.patch_size, opt.resolution

        strides, kernels = [], []

        # One round per level: an axis gets stride 2 while its spacing is at
        # most twice the smallest spacing and its size is still >= 8, and
        # kernel 3 while its spacing is at most twice the smallest spacing.
        # The loop stops once no axis gets a stride of 2 any more.
        while True:
            spacing_ratio = [sp / min(spacings) for sp in spacings]
            stride = [
                2 if ratio <= 2 and size >= 8 else 1
                for (ratio, size) in zip(spacing_ratio, sizes)
            ]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            # True division: sizes become floats after the first round.
            sizes = [i / j for i, j in zip(sizes, stride)]
            spacings = [i * j for i, j in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
        # The first level has stride 1; the last level uses kernel 3.
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])

        net = monai.networks.nets.DynUNet(
            spatial_dims=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            res_block=True,
            # act=act_type,
            # norm=Norm.BATCH,
        ).to(device)

        from torchsummaryX import summary

        # Dummy batch used to print the model summary. It is placed on the
        # same device as the network (the original forced it onto CUDA,
        # which breaks whenever `device` is not the current CUDA device);
        # the deprecated `Variable` wrapper is a no-op and is dropped.
        data = torch.randn(int(opt.batch_size), int(opt.in_channels),
                           int(opt.patch_size[0]), int(opt.patch_size[1]),
                           int(opt.patch_size[2])).to(device)

        out = net(data)
        summary(net, data)
        print("out size: {}".format(out.size()))

        # if opt.preload is not None:
        #     net.load_state_dict(torch.load(opt.preload))

        # ---------- ------------------------ ------------

        optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
        # polynomial decay: lr * (1 - epoch / opt.epochs) ** 0.9
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)

        loss_function = monai.losses.DiceCELoss(sigmoid=True)

        val_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
        ])

        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={"net": net},
                            save_key_metric=True),
        ]

        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loaders[index],
            network=net,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                         sw_batch_size=opt.batch_size,
                                         overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice":
                MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
            val_handlers=val_handlers)

        train_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])

        train_handlers = [
            ValidationHandler(validator=evaluator,
                              interval=5,
                              epoch_level=True),
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={
                                "net": net,
                                "opt": optim
                            },
                            save_final=True,
                            epoch_level=True),
        ]

        trainer = SupervisedTrainer(
            device=device,
            max_epochs=opt.epochs,
            train_data_loader=train_loaders[index],
            network=net,
            optimizer=optim,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=train_post_transforms,
            amp=False,
            train_handlers=train_handlers,
        )
        trainer.run()
        return net
Ejemplo n.º 12
0
def create_trainer(args):
    """Assemble the interactive-segmentation trainer without running it.

    Sets the random seed, selects the device, builds loaders, network,
    evaluator, loss, optimizer, LR schedule and handlers, and returns the
    configured ``SupervisedTrainer``.

    Args:
        args: namespace of options. Attributes read here: ``seed``,
            ``multi_gpu``, ``local_rank``, ``use_gpu``, ``roi_size``,
            ``model_size``, ``dimensions``, ``network``, ``channels``,
            ``resume``, ``model_filepath``, ``output``, ``image_interval``,
            ``save_interval``, ``max_val_interactions``, ``learning_rate``,
            ``val_freq``, ``epochs``, ``max_train_interactions`` and
            ``amp``. ``args`` is also forwarded to ``get_loaders``.

    Returns:
        The trainer; the caller is expected to call ``run()`` on it.
    """
    set_determinism(seed=args.seed)

    multi_gpu = args.multi_gpu
    local_rank = args.local_rank
    is_main_rank = local_rank == 0

    if multi_gpu:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    elif args.use_gpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    pre_transforms = get_pre_transforms(args.roi_size, args.model_size,
                                        args.dimensions)
    click_transforms = get_click_transforms()
    post_transform = get_post_transforms()

    train_loader, val_loader = get_loaders(args, pre_transforms)

    # define training components
    network = get_network(args.network, args.channels, args.dimensions)
    network = network.to(device)
    if multi_gpu:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[local_rank], output_device=local_rank)

    if args.resume:
        logging.info(f'{local_rank}:: Loading Network...')
        # Tensors saved from cuda:0 are remapped to this process's GPU.
        map_location = {"cuda:0": f"cuda:{local_rank}"}
        state = torch.load(args.model_filepath, map_location=map_location)
        network.load_state_dict(state)

    # Output transforms shared by the handlers and metrics below.
    def _no_output(output):
        return None

    def _loss_only(output):
        return output["loss"]

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    # define event-handlers for engine
    val_handlers = [
        StatsHandler(output_transform=_no_output),
        TensorBoardStatsHandler(log_dir=args.output,
                                output_transform=_no_output),
        DeepgrowStatsHandler(log_dir=args.output,
                             tag_name='val_dice',
                             image_interval=args.image_interval),
        CheckpointSaver(save_dir=args.output,
                        save_dict={"net": network},
                        save_key_metric=True,
                        save_final=True,
                        save_interval=args.save_interval,
                        final_filename='model.pt'),
    ]
    # Validation handlers are used on rank 0 only.
    if not is_main_rank:
        val_handlers = None

    val_interaction = Interaction(
        transforms=click_transforms,
        max_interactions=args.max_val_interactions,
        key_probability='probability',
        train=False)
    val_dice = MeanDice(include_background=False,
                        output_transform=_pred_and_label)
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        iteration_update=val_interaction,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_val_metric={"val_dice": val_dice},
        val_handlers=val_handlers)

    loss_function = DiceLoss(sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(network.parameters(), args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5000,
                                                   gamma=0.1)

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator,
                          interval=args.val_freq,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=_loss_only),
        TensorBoardStatsHandler(log_dir=args.output,
                                tag_name="train_loss",
                                output_transform=_loss_only),
        CheckpointSaver(save_dir=args.output,
                        save_dict={
                            "net": network,
                            "opt": optimizer,
                            "lr": lr_scheduler
                        },
                        save_interval=args.save_interval * 2,
                        save_final=True,
                        final_filename='checkpoint.pt'),
    ]
    # Other ranks keep only the LR schedule and validation handlers.
    if not is_main_rank:
        train_handlers = train_handlers[:2]

    train_interaction = Interaction(
        transforms=click_transforms,
        max_interactions=args.max_train_interactions,
        key_probability='probability',
        train=True)
    train_dice = MeanDice(include_background=False,
                          output_transform=_pred_and_label)
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=args.epochs,
        train_data_loader=train_loader,
        network=network,
        iteration_update=train_interaction,
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        amp=args.amp,
        key_train_metric={"train_dice": train_dice},
        train_handlers=train_handlers,
    )
    return trainer
Ejemplo n.º 13
0
def main(tempdir):
    """Train a 3D UNet on synthetic volumes written to ``tempdir``.

    Forty synthetic image/segmentation pairs are generated as NIfTI files.
    The first 20 are used for training and the last 20 for validation. A
    ``SupervisedTrainer`` runs for 5 epochs with validation every 2 epochs,
    writing TensorBoard logs and checkpoints to ``./runs/``.

    Args:
        tempdir: existing directory that receives the generated files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # ------------------------------ dataset ------------------------------
    # write 40 random image/mask pairs into the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    identity_affine_source = np.eye
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        for stem, volume in (("img", im), ("seg", seg)):
            nifti = nib.Nifti1Image(volume, identity_affine_source(4))
            nib.save(nifti, os.path.join(tempdir, f"{stem}{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    def _as_records(image_paths, label_paths):
        return [{"image": img, "label": seg} for img, seg in zip(image_paths, label_paths)]

    train_files = _as_records(images[:20], segs[:20])
    val_files = _as_records(images[-20:], segs[-20:])

    # transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # training loader: batch_size=2 with 4 crops per image gives 2 x 4
    # patches per network step
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # validation loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # ------------------------------ network ------------------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # ------------------- loss, optimizer, LR schedule --------------------
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    # Output/batch transforms shared by handlers and metrics.
    def _no_output(output):
        return None

    def _loss_only(output):
        return output["loss"]

    def _pred_only(output):
        return output["pred"]

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    def _image_and_label(batch):
        return (batch["image"], batch["label"])

    # ----------------------------- evaluator -----------------------------
    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=_no_output),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=_no_output),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=_image_and_label,
            output_transform=_pred_only,
        ),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]

    # AMP is enabled only for PyTorch >= 1.6
    eval_amp = monai.utils.get_torch_version_tuple() >= (1, 6)
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=_pred_and_label)
        },
        additional_metrics={"val_acc": Accuracy(output_transform=_pred_and_label)},
        val_handlers=val_handlers,
        amp=eval_amp,
    )

    # ------------------------------ trainer ------------------------------
    train_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=_loss_only),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=_loss_only),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]

    # AMP is enabled only for PyTorch >= 1.6
    train_amp = monai.utils.get_torch_version_tuple() >= (1, 6)
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=_pred_and_label)},
        train_handlers=train_handlers,
        amp=train_amp,
    )
    trainer.run()
Ejemplo n.º 14
0
def main():
    """Train a 3D UNet on synthetic data stored in a temporary directory.

    The temporary directory is removed when the workflow finishes, and also
    when it fails: the original removed it only on success, leaving the
    generated volumes on disk after any error.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        _train_on_synthetic_data(tempdir)
    finally:
        shutil.rmtree(tempdir)


def _train_on_synthetic_data(tempdir):
    """Generate 40 synthetic volumes in ``tempdir`` and train a UNet on them.

    The first 20 image/segmentation pairs are used for training and the last
    20 for validation. Training runs for 5 epochs with validation every 2
    epochs.

    Args:
        tempdir: existing directory that receives the generated NIfTI files.
    """
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{
        Keys.IMAGE: img,
        Keys.LABEL: seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        Keys.IMAGE: img,
        Keys.LABEL: seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
        AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
        ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
        RandCropByPosNegLabeld(keys=[Keys.IMAGE, Keys.LABEL],
                               label_key=Keys.LABEL,
                               size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=[Keys.IMAGE, Keys.LABEL],
                      prob=0.5,
                      spatial_axes=[0, 2]),
        ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
        AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
        ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
        ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
    ])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    val_handlers = [StatsHandler(output_transform=lambda x: None)]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        val_handlers=val_handlers,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     add_sigmoid=True,
                     output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL]))
        },
        additional_metrics=None,
    )

    train_handlers = [
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x[Keys.INFO][Keys.LOSS]),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        train_handlers=train_handlers,
        amp=False,
        key_train_metric=None,
    )
    trainer.run()
Ejemplo n.º 15
0
 def val_key_metric(self, context):
     """Return the key validation metric as a one-entry dict.

     The entry is keyed by ``self.VAL_KEY_METRIC`` and holds a ``MeanDice``
     reading the ``"pred"`` and ``"label"`` engine outputs. ``context`` is
     not used.
     """
     metric_name = self.VAL_KEY_METRIC
     dice = MeanDice(output_transform=from_engine(["pred", "label"]))
     return {metric_name: dice}
Ejemplo n.º 16
0
 def train_key_metric(self, context: Context):
     """Return the key training metric as a one-entry dict.

     The entry is keyed by ``self.TRAIN_KEY_METRIC`` and holds a ``MeanDice``
     reading the ``"pred"`` and ``"label"`` engine outputs. ``context`` is
     not used.
     """
     metric_name = self.TRAIN_KEY_METRIC
     dice = MeanDice(output_transform=from_engine(["pred", "label"]))
     return {metric_name: dice}
Ejemplo n.º 17
0
def _evaluate_synthetic_dict_data(data_dir):
    """Evaluate a UNet on the ``im*.nii.gz`` / ``seg*.nii.gz`` pairs in *data_dir*.

    Builds a dictionary-based dataset from the files, runs sliding-window
    inference inside an ignite ``Engine`` with ``MeanDice``, a stats printer,
    a segmentation saver and a checkpoint loader attached, then prints the
    final engine state.
    """
    images = sorted(glob(os.path.join(data_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(data_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    device = torch.device("cuda:0")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        # one evaluation iteration: returns (network output, labels)
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch["img"].to(device), batch["seg"].to(
                device)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

    # StatsHandler prints the loss at every iteration and the metrics at every
    # epoch; the evaluator has no loss to print, so only metrics are shown.
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # convert the necessary metadata from batch data
    # NOTE(review): output_dir is the literal directory name "tempdir"
    # (relative to the working directory), not *data_dir*, so the saved
    # segmentations outlive the temporary data -- confirm this is intended.
    SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda batch: batch["img_meta_dict"],
        output_transform=lambda output: predict_segmentation(output[0]),
    ).attach(evaluator)
    # the model was trained by "unet_training_dict" example
    CheckpointLoader(load_path="./runs/net_checkpoint_50.pth",
                     load_dict={
                         "net": net
                     }).attach(evaluator)

    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())
    state = evaluator.run(val_loader)
    print(state)


def main():
    """Generate five synthetic image/label volumes and evaluate a UNet on them.

    The volumes are written to a fresh temporary directory. That directory is
    removed in a ``finally`` clause, so it is cleaned up even when data
    generation or the evaluation raises (previously it was only removed after
    a fully successful run).
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        _evaluate_synthetic_dict_data(tempdir)
    finally:
        shutil.rmtree(tempdir)
Ejemplo n.º 18
0
    def configure(self):
        """Assemble the network, data loaders, metric and both engines.

        Reads configuration from the instance (``device``, ``multi_gpu``,
        ``data_list_file_path``, ``ckpt_dir``, ``amp``, ``learning_rate``,
        ``val_interval``, ``max_epochs``, ``local_rank``) and stores the
        results as ``self.eval_engine`` and ``self.train_engine``. Nothing is
        returned.
        """
        self.set_device()

        image_and_label = ("image", "label")
        # intensity-rescaling arguments shared by the train and val pipelines
        intensity_window = {
            "keys": "image",
            "a_min": -57,
            "a_max": 164,
            "b_min": 0.0,
            "b_max": 1.0,
            "clip": True,
        }

        def _no_iteration_output(_output):
            # handed to the stats handlers of the evaluator as output_transform
            return None

        def _loss_from_output(output):
            return output["loss"]

        def _pred_and_label(output):
            return output["pred"], output["label"]

        # ---- network -------------------------------------------------------
        unet = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        network = unet.to(self.device)
        if self.multi_gpu:
            network = DistributedDataParallel(
                module=network,
                device_ids=[self.device],
                find_unused_parameters=False,
            )

        # ---- training data -------------------------------------------------
        train_steps = [
            LoadImaged(keys=image_and_label),
            EnsureChannelFirstd(keys=image_and_label),
            Spacingd(keys=image_and_label,
                     pixdim=[1.0, 1.0, 1.0],
                     mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(**intensity_window),
            CropForegroundd(keys=image_and_label, source_key="image"),
            RandCropByPosNegLabeld(
                keys=image_and_label,
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=image_and_label),
        ]
        train_transforms = Compose(train_steps)

        train_datalist = load_decathlon_datalist(self.data_list_file_path,
                                                 True, "training")
        if self.multi_gpu:
            # every rank keeps only its own partition of the training list
            partitions = partition_dataset(
                data=train_datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )
            train_datalist = partitions[dist.get_rank()]

        train_ds = CacheDataset(
            data=train_datalist,
            transform=train_transforms,
            cache_num=32,
            cache_rate=1.0,
            num_workers=4,
        )
        train_data_loader = DataLoader(train_ds,
                                       batch_size=2,
                                       shuffle=True,
                                       num_workers=4)

        # ---- validation data -----------------------------------------------
        val_steps = [
            LoadImaged(keys=image_and_label),
            EnsureChannelFirstd(keys=image_and_label),
            ScaleIntensityRanged(**intensity_window),
            CropForegroundd(keys=image_and_label, source_key="image"),
            ToTensord(keys=image_and_label),
        ]
        val_transforms = Compose(val_steps)

        val_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                               "validation")
        val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
        val_data_loader = DataLoader(val_ds,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=4)

        # ---- post-processing shared by both engines --------------------------
        softmax_pred = Activationsd(keys="pred", softmax=True)
        discretize = AsDiscreted(
            keys=["pred", "label"],
            argmax=[True, False],
            to_onehot=True,
            n_classes=2,
        )
        post_transform = Compose([softmax_pred, discretize])

        # ---- evaluator -------------------------------------------------------
        mean_dice = MeanDice(
            include_background=False,
            output_transform=_pred_and_label,
            device=self.device,
        )
        key_val_metric = {"val_mean_dice": mean_dice}

        val_handlers = [
            StatsHandler(output_transform=_no_iteration_output),
            CheckpointSaver(
                save_dir=self.ckpt_dir,
                save_dict={"model": network},
                save_key_metric=True,
            ),
            TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                    output_transform=_no_iteration_output),
        ]
        val_inferer = SlidingWindowInferer(
            roi_size=[160, 160, 160],
            sw_batch_size=4,
            overlap=0.5,
        )
        self.eval_engine = SupervisedEvaluator(
            device=self.device,
            val_data_loader=val_data_loader,
            network=network,
            inferer=val_inferer,
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            val_handlers=val_handlers,
            amp=self.amp,
        )

        # ---- trainer ---------------------------------------------------------
        optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
        loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=5000,
                                                       gamma=0.1)
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            # the evaluator is run from the trainer every val_interval epochs
            ValidationHandler(validator=self.eval_engine,
                              interval=self.val_interval,
                              epoch_level=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=_loss_from_output),
            TensorBoardStatsHandler(
                log_dir=self.ckpt_dir,
                tag_name="train_loss",
                output_transform=_loss_from_output,
            ),
        ]

        self.train_engine = SupervisedTrainer(
            device=self.device,
            max_epochs=self.max_epochs,
            train_data_loader=train_data_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=post_transform,
            key_train_metric=None,
            train_handlers=train_handlers,
            amp=self.amp,
        )

        # ranks other than 0 only log warnings and above
        if self.local_rank > 0:
            for engine in (self.train_engine, self.eval_engine):
                engine.logger.setLevel(logging.WARNING)
Ejemplo n.º 19
0
def _evaluate_synthetic_array_data(data_dir):
    """Evaluate a UNet on the ``im*.nii.gz`` / ``seg*.nii.gz`` pairs in *data_dir*.

    Builds an array-based ``NiftiDataset`` from the files and runs
    sliding-window inference inside an ignite ``Engine`` with ``MeanDice``, a
    stats printer, a segmentation saver and a checkpoint loader attached.
    """
    images = sorted(glob(os.path.join(data_dir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(data_dir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device('cuda:0')
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        # one evaluation iteration: returns (network output, labels)
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

    # StatsHandler prints the loss at every iteration and the metrics at every
    # epoch; the evaluator has no loss to print, so only metrics are shown.
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    # NOTE(review): output_dir is the literal directory name 'tempdir'
    # (relative to the working directory), not *data_dir* -- confirm intended.
    file_saver = SegmentationSaver(
        output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
        batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    evaluator.run(loader)


def main():
    """Generate five synthetic image/label volumes and evaluate a UNet on them.

    The volumes are written to a fresh temporary directory. That directory is
    removed in a ``finally`` clause, so it is cleaned up even when data
    generation or the evaluation raises (previously it was only removed after
    a fully successful run).
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        _evaluate_synthetic_array_data(tempdir)
    finally:
        shutil.rmtree(tempdir)
Ejemplo n.º 20
0

def _sliding_window_processor(engine, batch):
    """Run one evaluation iteration and return ``(predictions, labels)``.

    Uses the module-level ``net``, ``device``, ``roi_size`` and
    ``sw_batch_size``. ``engine`` is accepted but not used.
    """
    net.eval()
    with torch.no_grad():
        inputs = batch['img'].to(device)
        targets = batch['seg'].to(device)
        predictions = sliding_window_inference(inputs, roi_size,
                                               sw_batch_size, net)
    return predictions, targets


# Engine that calls _sliding_window_processor once per batch.
evaluator = Engine(_sliding_window_processor)

# Add the evaluation metric to the evaluator engine under the name 'Mean_Dice'.
MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

# StatsHandler prints the loss at every iteration and the metrics at every epoch.
# The evaluator has no loss to print, so only the metrics are shown; users can
# also customize the print functions.
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x:
    None  # no need to print loss value, so disable per iteration output
)
val_stats_handler.attach(evaluator)

# convert the necessary metadata from batch data
SegmentationSaver(
    output_dir='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
Ejemplo n.º 21
0
def train(data_folder=".", model_folder="runs"):
    """Run a training pipeline.

    Pairs the sorted ``*_ct.nii.gz`` and ``*_seg.nii.gz`` files found in
    ``data_folder``, splits them into a training and a validation subset,
    and trains the network returned by ``get_net()`` with a
    ``SupervisedTrainer``. A ``SupervisedEvaluator`` is run after every epoch
    and saves key-metric checkpoints into ``model_folder``.

    Args:
        data_folder: directory searched for the image and label volumes.
        model_folder: directory handed to ``CheckpointSaver`` as ``save_dir``.
    """

    images = sorted(glob.glob(os.path.join(data_folder, "*_ct.nii.gz")))
    labels = sorted(glob.glob(os.path.join(data_folder, "*_seg.nii.gz")))
    logging.info(
        f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    # Clamp at zero: with very few volumes the training split already
    # consumes everything and the raw value can be 0 or negative.
    n_val = max(0, min(len(images) - n_train, int(val_frac * len(images))))
    logging.info(
        f"training: train {n_train} val {n_val}, folder: {data_folder}")
    if n_val == 0:
        logging.warning("training: no volumes left for validation")

    train_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[:n_train], labels[:n_train])]
    # ``seq[-0:]`` is the whole sequence, so an empty validation split must be
    # spelled out; otherwise validation would silently run on the training
    # volumes.
    val_images = images[-n_val:] if n_val > 0 else []
    val_labels = labels[-n_val:] if n_val > 0 else []
    val_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(val_images, val_labels)]

    # create a training data loader
    batch_size = 2
    logging.info(f"batch size {batch_size}")
    train_transforms = get_xforms("train", keys)
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=
        1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create the network and the Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)
    # momentum is only logged; Adam below is created with lr alone
    max_epochs, lr, momentum = 500, 1e-4, 0.95
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose([
        AsDiscreted(keys=("pred", "label"),
                    argmax=(True, False),
                    to_onehot=True,
                    n_classes=2)
    ])
    val_handlers = [
        ProgressBar(),
        CheckpointSaver(save_dir=model_folder,
                        save_dict={"net": net},
                        save_key_metric=True,
                        key_metric_n_saved=3),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        post_transform=val_post_transform,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
def main(tempdir):
    """Write five synthetic volumes into ``tempdir`` and evaluate a UNet on them.

    The image/label pairs are loaded through an array-based ``ImageDataset``,
    inference is done with ``sliding_window_inference`` inside an ignite
    ``Engine``, and the final engine state is printed.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        image_nifti = nib.Nifti1Image(im, np.eye(4))
        nib.save(image_nifti, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        label_nifti = nib.Nifti1Image(seg, np.eye(4))
        nib.save(label_nifti, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    image_paths = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # transforms applied to the images and to the segmentations
    image_transform = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    label_transform = Compose([AddChannel(), ToTensor()])
    dataset = ImageDataset(
        image_paths,
        label_paths,
        transform=image_transform,
        seg_transform=label_transform,
        image_only=False,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net = net.to(device)

    # sliding window size and the number of windows per inference batch
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    # applied to the raw network output inside every iteration
    post_trans = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True),
    ])

    def _sliding_window_processor(engine, batch):
        # one evaluation iteration: returns (post-processed output, labels)
        net.eval()
        with torch.no_grad():
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            outputs = sliding_window_inference(inputs, roi_size,
                                               sw_batch_size, net)
            outputs = post_trans(outputs)
        return outputs, targets

    evaluator = Engine(_sliding_window_processor)

    # register the evaluation metric on the evaluator engine
    dice_metric = MeanDice()
    dice_metric.attach(evaluator, "Mean_Dice")

    # StatsHandler prints the loss at every iteration and the metrics at every
    # epoch; the evaluator has no loss to print, so only metrics are shown.
    def _no_iteration_output(x):
        return None

    val_stats_handler = StatsHandler(name="evaluator",
                                     output_transform=_no_iteration_output)
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: output[0],
    )
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    checkpoint_path = "./runs_array/net_checkpoint_100.pt"
    ckpt_saver = CheckpointLoader(load_path=checkpoint_path,
                                  load_dict={"net": net})
    ckpt_saver.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
    )
    final_state = evaluator.run(loader)
    print(final_state)
Ejemplo n.º 23
0
def _train_on_synthetic_volumes(tempdir):
    """Train and validate a UNet on the NIfTI pairs stored in *tempdir*.

    The first 20 sorted image/label pairs are used for training and the last
    20 for validation. Training runs for 30 epochs through an ignite trainer;
    validation runs after every epoch and drives early stopping, console and
    TensorBoard logging.
    """
    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    train_imtrans = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96, 96), random_size=False),
        ToTensor()
    ])
    train_segtrans = Compose([
        AddChannel(),
        RandSpatialCrop((96, 96, 96), random_size=False),
        ToTensor()
    ])
    val_imtrans = Compose([
        ScaleIntensity(),
        AddChannel(),
        Resize((96, 96, 96)),
        ToTensor()
    ])
    val_segtrans = Compose([
        AddChannel(),
        Resize((96, 96, 96)),
        ToTensor()
    ])

    # define nifti dataset, data loader; print the shapes of the first batch
    check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss = monai.losses.DiceLoss(do_sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device('cuda:0')

    # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={'net': net, 'opt': opt})

    # StatsHandler prints loss at every iteration and prints metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not a loss value
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    validation_every_n_epochs = 1
    # Set parameters for validation
    metric_name = 'Mean_Dice'
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

    # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True)

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4,
                                  score_function=stopping_fn_from_metric(metric_name),
                                  trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along Depth axis, at every validation epoch
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch[0], batch[1]),
        output_transform=lambda output: predict_segmentation(output[0]),
        global_iter_transform=lambda x: trainer.state.epoch
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

    train_epochs = 30
    trainer.run(train_loader, train_epochs)


def main():
    """Generate 40 synthetic image/label volumes and train a UNet on them.

    The volumes are written to a fresh temporary directory. That directory is
    removed in a ``finally`` clause, so it is cleaned up even when data
    generation or training raises (previously it was only removed after a
    fully successful run).
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        _train_on_synthetic_volumes(tempdir)
    finally:
        shutil.rmtree(tempdir)
Ejemplo n.º 24
0
def train(data_folder=".", model_folder="runs", continue_training=False):
    """Run a training pipeline that uses the "synthesis" training transforms.

    Pairs the first ten sorted ``*_ct.nii.gz`` and ``*_seg.nii.gz`` files
    found in ``data_folder``, splits them into a training and a validation
    subset, and trains the network returned by ``get_net()`` with a
    ``SupervisedTrainer``. A ``SupervisedEvaluator`` is run after every epoch
    and saves metrics to ``./metrics_val`` and key-metric checkpoints into
    ``model_folder``.

    Args:
        data_folder: directory searched for the image and label volumes.
        model_folder: directory handed to ``CheckpointSaver`` as ``save_dir``
            and searched for ``*.pt`` files when resuming.
        continue_training: when true, load the state dict from the last
            ``*.pt`` file (in sorted name order) of ``model_folder`` before
            training.
    """

    # --- files for synthesis
    path_parent = Path(
        '/content/drive/My Drive/Datasets/covid19/COVID-19-20_augs_cea/')
    path_synthesis = Path(
        path_parent /
        'CeA_BASE_grow=1_bg=-1.00_step=-1.0_scale=-1.0_seed=1.0_ch0_1=-1_ch1_16=-1_ali_thr=0.1'
    )
    scans_syns = os.listdir(path_synthesis)
    decreasing_sequence = get_decreasing_sequence(255, splits=20)
    keys2 = ("image", "label", "synthetic_lesion")
    # read the synthetic healthy texture
    path_synthesis_old = '/content/drive/My Drive/Datasets/covid19/results/cea_synthesis/patient0/'
    texture_orig = np.load(f'{path_synthesis_old}texture.npy.npz')
    texture_orig = texture_orig.f.arr_0
    # shift the texture so that its minimum becomes .07, then reflect-pad it
    texture = texture_orig + np.abs(np.min(texture_orig)) + .07
    texture = np.pad(texture, ((100, 100), (100, 100)), mode='reflect')
    print(f'type(texture) = {type(texture)}, {np.shape(texture)}')
    # --- end of files for synthesis

    # Sort BEFORE truncating: glob returns names in arbitrary order, so
    # slicing first could pick ten images and ten labels that do not belong
    # to the same cases.
    images = sorted(glob.glob(os.path.join(data_folder,
                                           "*_ct.nii.gz")))[:10]  #OMM
    labels = sorted(glob.glob(os.path.join(data_folder,
                                           "*_seg.nii.gz")))[:10]  #OMM
    logging.info(
        f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    # Clamp at zero: with very few volumes the training split already
    # consumes everything and the raw value can be 0 or negative.
    n_val = max(0, min(len(images) - n_train, int(val_frac * len(images))))
    logging.info(
        f"training: train {n_train} val {n_val}, folder: {data_folder}")
    if n_val == 0:
        logging.warning("training: no volumes left for validation")

    train_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[:n_train], labels[:n_train])]
    # ``seq[-0:]`` is the whole sequence, so an empty validation split must be
    # spelled out; otherwise validation would silently run on the training
    # volumes.
    val_images = images[-n_val:] if n_val > 0 else []
    val_labels = labels[-n_val:] if n_val > 0 else []
    val_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(val_images, val_labels)]

    # create a training data loader
    batch_size = 1  # XX was 2
    logging.info(f"batch size {batch_size}")
    train_transforms = get_xforms("synthesis", keys, keys2, path_synthesis,
                                  decreasing_sequence, scans_syns, texture)
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        # collate_fn=pad_list_data_collate,
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=
        1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create the network and the Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)

    # if continue training: resume from the last checkpoint in sorted name order
    if continue_training:
        ckpts = sorted(glob.glob(os.path.join(model_folder, "*.pt")))
        ckpt = ckpts[-1]
        logging.info(f"continue training using {ckpt}.")
        net.load_state_dict(torch.load(ckpt, map_location=device))

    # max_epochs, lr, momentum = 500, 1e-4, 0.95
    # momentum is only logged; Adam below is created with lr alone
    max_epochs, lr, momentum = 20, 1e-4, 0.95  #OMM
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose([
        AsDiscreted(keys=("pred", "label"),
                    argmax=(True, False),
                    to_onehot=True,
                    n_classes=2)
    ])
    val_handlers = [
        ProgressBar(),
        MetricsSaver(save_dir="./metrics_val", metrics="*"),
        CheckpointSaver(save_dir=model_folder,
                        save_dict={"net": net},
                        save_key_metric=True,
                        key_metric_n_saved=6),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        post_transform=val_post_transform,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        # MetricsSaver(save_dir="./metrics_train", metrics="*"),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
def evaluate(args):
    """Evaluate a 3D UNet on synthetic data, running one process per GPU.

    The process whose ``args.local_rank`` is 0 creates 16 synthetic
    image/segmentation pairs in ``args.dir`` when that directory does not
    exist yet. Every process then joins the NCCL process group, loads the
    checkpoint ``./runs/checkpoint_epoch=4.pt`` and runs sliding-window
    inference on the samples handed out by its ``DistributedSampler``.
    Only the process with global rank 0 configures logging, prints
    statistics and saves the predicted segmentations to ``./runs/``.

    The process group is destroyed on exit even when evaluation fails.

    Args:
        args: object exposing ``local_rank`` (index of the CUDA device used by
            this process) and ``dir`` (folder holding the ``img*.nii.gz`` and
            ``seg*.nii.gz`` files).
    """
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    try:
        images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
        val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
                ScaleIntensityd(keys="image"),
                ToTensord(keys=["image", "label"]),
            ]
        )

        # create an evaluation data loader
        val_ds = Dataset(data=val_files, transform=val_transforms)
        # create an evaluation data sampler
        val_sampler = DistributedSampler(val_ds, shuffle=False)
        # sliding window inference needs 1 image as input in every iteration
        val_loader = DataLoader(
            val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler
        )

        # create the UNet on the GPU assigned to this process
        device = torch.device(f"cuda:{args.local_rank}")
        torch.cuda.set_device(device)
        net = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        # wrap the model with DistributedDataParallel module
        net = DistributedDataParallel(net, device_ids=[device])

        val_post_transforms = Compose(
            [
                Activationsd(keys="pred", sigmoid=True),
                AsDiscreted(keys="pred", threshold_values=True),
                KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
            ]
        )
        val_handlers = [
            CheckpointLoader(
                load_path="./runs/checkpoint_epoch=4.pt",
                load_dict={"net": net},
                # config mapping to expected GPU device
                map_location={"cuda:0": f"cuda:{args.local_rank}"},
            ),
        ]
        if dist.get_rank() == 0:
            # only the first process reports statistics and writes predictions
            logging.basicConfig(stream=sys.stdout, level=logging.INFO)
            val_handlers.extend(
                [
                    StatsHandler(output_transform=lambda x: None),
                    SegmentationSaver(
                        output_dir="./runs/",
                        batch_transform=lambda batch: batch["image_meta_dict"],
                        output_transform=lambda output: output["pred"],
                    ),
                ]
            )

        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loader,
            network=net,
            inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice": MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                    device=device,
                )
            },
            additional_metrics={
                "val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)
            },
            val_handlers=val_handlers,
            # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
            amp=monai.config.get_torch_version_tuple() >= (1, 6),
        )
        evaluator.run()
    finally:
        # always release the process group, even when setup or evaluation raised
        dist.destroy_process_group()
Ejemplo n.º 26
0
def main():
    """Train a 3D UNet on synthetic data with Ignite and periodic validation.

    Creates a temporary directory, delegates data generation and training to
    ``_train_unet_on_synthetic_data`` and removes the directory afterwards,
    also when training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        _train_unet_on_synthetic_data(tempdir)
    finally:
        # the generated volumes are only needed for this run; never leave them behind
        shutil.rmtree(tempdir)


def _train_unet_on_synthetic_data(tempdir):
    """Generate 40 synthetic pairs in ``tempdir`` and run training with validation.

    The first 20 pairs are used for training and the last 20 for validation.
    Checkpoints are written to ``./runs/``.

    Args:
        tempdir: existing directory that receives the generated
            ``img*.nii.gz`` / ``seg*.nii.gz`` files.
    """
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])

    # define dataset, data loader; load one batch to print the tensor shapes
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=5,
                            num_workers=8,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss = monai.losses.DiceLoss(sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)
    # fall back to CPU when CUDA is unavailable instead of failing on "cuda:0"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["seg"]), device, non_blocking)

    trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={"net": net, "opt": opt})

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    validation_every_n_iters = 5
    # set parameters for validation
    metric_name = "Mean_Dice"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)}

    # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4,
                                  score_function=stopping_fn_from_metric(metric_name),
                                  trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        # no need to print loss value, so disable per iteration output
        output_transform=lambda x: None,
        # fetch global epoch number from trainer
        global_epoch_transform=lambda x: trainer.state.epoch,
    )
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        # no need to plot loss value, so disable per iteration output
        output_transform=lambda x: None,
        # fetch global iteration number from trainer
        global_epoch_transform=lambda x: trainer.state.iteration,
    )
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch["img"], batch["seg"]),
        output_transform=lambda output: predict_segmentation(output[0]),
        global_iter_transform=lambda x: trainer.state.epoch,
    )
    evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED(every=2),
                                handler=val_tensorboard_image_handler)

    train_epochs = 5
    state = trainer.run(train_loader, train_epochs)
    print(state)
Ejemplo n.º 27
0
def main(tempdir):
    """Evaluate a saved UNet checkpoint on freshly generated synthetic volumes.

    Five synthetic image/segmentation pairs are written into ``tempdir``, the
    first path returned by globbing ``./runs/net_key_metric*`` is loaded as
    the checkpoint, and sliding-window inference is run with mean Dice as the
    key metric and accuracy as an additional metric. Predictions are saved
    under ``./runs/``.

    Args:
        tempdir: directory that receives the generated ``im*.nii.gz`` and
            ``seg*.nii.gz`` files (it is not created here).
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # write 5 random image/mask pairs into the caller-supplied directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        volumes = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        filenames = (f"im{i:d}.nii.gz", f"seg{i:d}.nii.gz")
        # the generator yields (image, segmentation); both are stored with an identity affine
        for volume, filename in zip(volumes, filenames):
            nib.save(nib.Nifti1Image(volume, np.eye(4)), os.path.join(tempdir, filename))

    # pair every image with its segmentation by sorted filename order
    image_paths = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [
        dict(image=image_path, label=label_path)
        for image_path, label_path in zip(image_paths, label_paths)
    ]

    # checkpoint to evaluate: first match among the saved key-metric models
    candidate_checkpoints = glob("./runs/net_key_metric*")
    model_file = candidate_checkpoints[0]

    # preprocessing applied to every image/label pair
    both_keys = ["image", "label"]
    val_transforms = Compose(
        [
            LoadNiftid(keys=both_keys),
            AsChannelFirstd(keys=both_keys, channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=both_keys),
        ]
    )

    # validation dataset and loader, one volume per iteration
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # network placed on the GPU when one is available, otherwise on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net = net.to(device)

    # turn raw network output into a cleaned-up binary mask
    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )

    def _no_iteration_output(_output):
        # suppress the per-iteration printout of StatsHandler
        return None

    def _meta_from_batch(batch):
        return batch["image_meta_dict"]

    def _pred_from_output(output):
        return output["pred"]

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    val_handlers = [
        StatsHandler(output_transform=_no_iteration_output),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=_meta_from_batch,
            output_transform=_pred_from_output,
        ),
    ]

    inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5)
    key_metric = {"val_mean_dice": MeanDice(include_background=True, output_transform=_pred_and_label)}
    extra_metrics = {"val_acc": Accuracy(output_transform=_pred_and_label)}
    # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
    use_amp = monai.config.get_torch_version_tuple() >= (1, 6)

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=inferer,
        post_transform=val_post_transforms,
        key_val_metric=key_metric,
        additional_metrics=extra_metrics,
        val_handlers=val_handlers,
        amp=use_amp,
    )
    evaluator.run()
Ejemplo n.º 28
0
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False):
    """Run sliding-window inference on the volumes in ``root_dir``.

    Args:
        root_dir: folder containing ``im*.nii.gz`` / ``seg*.nii.gz`` pairs;
            the predicted segmentations are saved into the same folder.
        model_file: checkpoint path handed to ``CheckpointLoader``.
        device: device the network is moved to and the evaluator runs on.
        amp: truthy to enable automatic mixed precision in the evaluator.

    Returns:
        ``evaluator.state.best_metric`` after the evaluation run.
    """
    # pair every image with its segmentation by sorted filename order
    image_paths = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [
        dict(image=image_path, label=label_path)
        for image_path, label_path in zip(image_paths, label_paths)
    ]

    # preprocessing applied to every image/label pair
    both_keys = ["image", "label"]
    val_transforms = Compose([
        LoadNiftid(keys=both_keys),
        AsChannelFirstd(keys=both_keys, channel_dim=-1),
        ScaleIntensityd(keys=both_keys),
        ToTensord(keys=both_keys),
    ])

    # validation dataset and loader, one volume per iteration
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # network to evaluate, moved to the requested device
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net = net.to(device)

    # turn raw network output into a cleaned-up binary mask
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])

    def _no_iteration_output(_output):
        # suppress the per-iteration printout of StatsHandler
        return None

    def _meta_from_batch(batch):
        return batch["image_meta_dict"]

    def _pred_from_output(output):
        return output["pred"]

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    val_handlers = [
        StatsHandler(output_transform=_no_iteration_output),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            batch_transform=_meta_from_batch,
            output_transform=_pred_from_output,
        ),
    ]

    inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5)
    key_metric = {"val_mean_dice": MeanDice(include_background=True, output_transform=_pred_and_label)}
    extra_metrics = {"val_acc": Accuracy(output_transform=_pred_and_label)}

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=inferer,
        post_transform=val_post_transforms,
        key_val_metric=key_metric,
        additional_metrics=extra_metrics,
        val_handlers=val_handlers,
        amp=bool(amp),
    )
    evaluator.run()

    return evaluator.state.best_metric
Ejemplo n.º 29
0
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    """

    """
    Read input and configuration parameters
    """
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    if 'seg_labels' in config_info['training'].keys():
        seg_labels = config_info['training']['seg_labels']
    else:
        seg_labels = [1]
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None:
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        else:
            print("Loading model from {}".format(model_to_load))
    else:
        model_to_load = None

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None:
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    """
    Setup data output directory
    """
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    """
    Data preparation
    """
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data 
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch size different from 1 at validation ar currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    """
    Network preparation
    """
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    """
    MONAI evaluator
    """
    print("*** Preparing the dynUNet evaluator engine...\n")
    # val_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        """Evaluator that averages the prediction over mirrored copies of the input."""

        def _iteration(self, engine, batchdata):
            """Run inference on one batch and return the averaged prediction.

            The network is evaluated on the input as given and on three
            flipped copies (dims 2, 3, and both). Each flipped prediction is
            flipped back before the four results are averaged.

            Args:
                engine: engine running the loop; ``engine.state.device`` is
                    the device the batch is moved to.
                batchdata: raw batch handed to ``self.prepare_batch``.

            Returns:
                dict with the keys ``"image"``, ``"label"`` and ``"pred"``.
            """
            inputs, targets = self.prepare_batch(batchdata)
            inputs = inputs.to(engine.state.device)
            targets = targets.to(engine.state.device)

            # The flips are fixed (not random): the same three mirrored views
            # are used for every batch.
            flip_axes = ((2,), (3,), (2, 3))
            mirrored_inputs = [torch.flip(inputs, dims=axes) for axes in flip_axes]

            def _compute_pred():
                # Start from the un-flipped prediction, then add each mirrored
                # prediction after undoing its flip.
                total = self.inferer(inputs, self.network)
                for axes, mirrored in zip(flip_axes, mirrored_inputs):
                    restored = torch.flip(self.inferer(mirrored, self.network), dims=axes)
                    total = total + restored
                return total / 4

            # forward pass only: eval mode, no gradient tracking
            self.network.eval()
            with torch.no_grad():
                if not self.amp:
                    predictions = _compute_pred()
                else:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()

            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        """Trainer that sums a weighted loss over several network outputs."""

        def _iteration(self, engine, batchdata):
            """Run one optimisation step on a single batch.

            Args:
                engine: engine running the loop; ``engine.state.device`` is
                    the device the batch is moved to.
                batchdata: raw batch handed to ``self.prepare_batch``.

            Returns:
                dict with the keys ``"image"``, ``"label"``, ``"pred"`` and
                ``"loss"``; the loss is returned as a Python number via
                ``loss.item()``.
            """
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # The first output is compared with the label as given; for
                # every further output the label is resized with
                # ``interpolate`` to that output's spatial shape.
                # NOTE(review): assumes ``preds`` iterates over the individual
                # network outputs -- confirm against the network in use.
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                # The weight halves with each successive output: 1, 0.5, 0.25, ...
                weighted = sum(
                    0.5 ** i * self.loss_function(pred, lbl)
                    for i, (pred, lbl) in enumerate(zip(preds, labels))
                )
                # Reduce to a scalar here so that both the AMP and the non-AMP
                # branch call backward()/item() on the same kind of value.
                # Previously only the non-AMP branch applied ``.mean()``.
                return weighted.mean()

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                # scaled backward pass and optimizer step through the grad scaler
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets)
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
Ejemplo n.º 30
0
def run_training_test(root_dir, device="cuda:0", amp=False):
    """Train and validate a small 3D UNet on the NIfTI pairs found in ``root_dir``.

    Files matching ``img*.nii.gz`` and ``seg*.nii.gz`` are sorted and paired
    by position. The first 20 pairs are used for training and the last 20 for
    validation. TensorBoard logs and checkpoints are written to ``root_dir``.

    Args:
        root_dir: directory holding the input volumes; also the log and
            checkpoint directory.
        device: device the network is moved to and that both engines receive.
        amp: truthy to switch on ``amp`` for the evaluator and the trainer.

    Returns:
        ``evaluator.state.best_metric`` once the trainer has finished.
    """

    # --- small local helpers -------------------------------------------------
    def _sorted_matches(pattern):
        return sorted(glob(os.path.join(root_dir, pattern)))

    def _as_records(image_paths, seg_paths):
        return [
            {"image": image_path, "label": seg_path}
            for image_path, seg_path in zip(image_paths, seg_paths)
        ]

    def _loading_steps():
        # Deterministic steps shared by the training and validation
        # pipelines; a fresh set of transform objects is built on every call.
        return [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys=["image", "label"]),
        ]

    def _post_steps():
        # Post-processing applied to "pred"; built once for validation and
        # once for training.
        return Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])

    def _pred_and_label(output):
        return (output["pred"], output["label"])

    def _image_and_label(batch):
        return (batch["image"], batch["label"])

    def _pred_only(output):
        return output["pred"]

    def _loss_only(output):
        return output["loss"]

    def _no_output(output):
        return None

    use_amp = bool(amp)

    # --- data ----------------------------------------------------------------
    image_paths = _sorted_matches("img*.nii.gz")
    seg_paths = _sorted_matches("seg*.nii.gz")
    train_files = _as_records(image_paths[:20], seg_paths[:20])
    val_files = _as_records(image_paths[-20:], seg_paths[-20:])

    # transforms for image and segmentation
    train_transforms = Compose(
        _loading_steps()
        + [
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                pos=1,
                neg=1,
                num_samples=4,
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(_loading_steps() + [ToTensord(keys=["image", "label"])])

    # training loader: batch_size=2 combined with num_samples=4 crops per
    # volume gives 2 x 4 images per training step
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # validation loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # --- model, loss, optimiser ----------------------------------------------
    network = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    network = network.to(device)
    criterion = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(network.parameters(), 1e-3)
    step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

    # --- evaluator -----------------------------------------------------------
    val_post_transforms = _post_steps()
    val_handlers = [
        StatsHandler(output_transform=_no_output),
        TensorBoardStatsHandler(log_dir=root_dir, output_transform=_no_output),
        TensorBoardImageHandler(
            log_dir=root_dir,
            batch_transform=_image_and_label,
            output_transform=_pred_only,
        ),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": network}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=_pred_and_label),
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=_pred_and_label),
        },
        val_handlers=val_handlers,
        amp=use_amp,
    )

    # --- trainer -------------------------------------------------------------
    train_post_transforms = _post_steps()
    train_handlers = [
        LrScheduleHandler(lr_scheduler=step_lr, print_lr=True),
        # validation is triggered every 2 epochs
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=_loss_only),
        TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=_loss_only),
        CheckpointSaver(
            save_dir=root_dir,
            save_dict={"net": network, "opt": optimizer},
            save_interval=2,
            epoch_level=True,
        ),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=network,
        optimizer=optimizer,
        loss_function=criterion,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_acc": Accuracy(output_transform=_pred_and_label),
        },
        train_handlers=train_handlers,
        amp=use_amp,
    )
    trainer.run()

    return evaluator.state.best_metric