def test_train_timing(self):
        images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
        train_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[:32], segs[:32])]
        val_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[-9:], segs[-9:])]

        device = torch.device("cuda:0")
        # define transforms for train and validation
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indices
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop patch samples from the large image
            # based on the pos/neg ratio; the centers of negative
            # samples must lie within the valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"],
                          prob=0.5,
                          spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"],
                      prob=0.5,
                      min_zoom=0.8,
                      max_zoom=1.2,
                      keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"],
                        prob=0.5,
                        rotate_range=np.pi / 2,
                        mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image",
                                   prob=0.5,
                                   factors=0.05,
                                   nonzero=True),
        ])

        val_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ])

        max_epochs = 5
        learning_rate = 2e-4
        val_interval = 1  # run validation every epoch

        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
        train_ds = CacheDataset(data=train_files,
                                transform=train_transforms,
                                cache_rate=1.0,
                                num_workers=8)
        val_ds = CacheDataset(data=val_files,
                              transform=val_transforms,
                              cache_rate=1.0,
                              num_workers=5)
        # disable multiprocessing workers because `ThreadDataLoader` uses multiple threads instead
        train_loader = ThreadDataLoader(train_ds,
                                        num_workers=0,
                                        batch_size=4,
                                        shuffle=True)
        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

        loss_function = DiceCELoss(to_onehot_y=True,
                                   softmax=True,
                                   squared_pred=True,
                                   batch=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(device)

        # the Novograd paper suggests using a larger LR than Adam,
        # because Adam normalizes by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
        scaler = torch.cuda.amp.GradScaler()

        post_pred = Compose(
            [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

        dice_metric = DiceMetric(include_background=True,
                                 reduction="mean",
                                 get_not_nans=False)

        best_metric = -1
        total_start = time.time()
        for epoch in range(max_epochs):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step_start = time.time()
                step += 1
                optimizer.zero_grad()
                # set AMP for training
                with torch.cuda.amp.autocast():
                    outputs = model(batch_data["image"])
                    loss = loss_function(outputs, batch_data["label"])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_loss += loss.item()
                epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                      f" step time: {(time.time() - step_start):.4f}")
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_loader:
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        # set AMP for validation
                        with torch.cuda.amp.autocast():
                            val_outputs = sliding_window_inference(
                                val_data["image"], roi_size, sw_batch_size,
                                model)

                        val_outputs = [
                            post_pred(i) for i in decollate_batch(val_outputs)
                        ]
                        val_labels = [
                            post_label(i)
                            for i in decollate_batch(val_data["label"])
                        ]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    if metric > best_metric:
                        best_metric = metric
                    print(
                        f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                    )
            print(
                f"time of epoch {epoch + 1}: {(time.time() - epoch_start):.4f}s"
            )

        total_time = time.time() - total_start
        print(
            f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
        )
        # test expected metrics
        self.assertGreater(best_metric, 0.95)
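The main speed trick above is pairing FgBgToIndicesd with RandCropByPosNegLabeld, so the cropper reuses cached foreground/background indices instead of rescanning the label volume on every epoch. A minimal, self-contained sketch of that pairing on synthetic arrays (the shapes and the 95% background threshold are illustrative, not taken from the test above):

import numpy as np
from monai.transforms import Compose, FgBgToIndicesd, RandCropByPosNegLabeld

data = {
    "image": np.random.rand(1, 64, 64, 64).astype(np.float32),
    "label": (np.random.rand(1, 64, 64, 64) > 0.95).astype(np.int8),
}
crop = Compose([
    # cache foreground/background voxel indices under "label_fg"/"label_bg"
    FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
    # the cropper reads the cached indices instead of scanning the label again
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=(32, 32, 32),
        pos=1,
        neg=1,
        num_samples=2,
        fg_indices_key="label_fg",
        bg_indices_key="label_bg",
    ),
])
samples = crop(data)  # a list of num_samples cropped dicts
print(samples[0]["image"].shape)  # (1, 32, 32, 32)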
Example #2
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # define ROI size and batch size for sliding window inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
            val_data = decollate_batch(batch["img_meta_dict"])
            for seg_prob, data in zip(seg_probs, val_data):
                save_image(seg_prob, data)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # the evaluator has no loss to print, so only metrics are reported (print functions can be customized)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # the model was trained by "unet_training_dict" example
    CheckpointLoader(load_path="./runs_dict/net_checkpoint_50.pt", load_dict={"net": net}).attach(evaluator)

    # sliding window inference for one image at every iteration
    val_loader = DataLoader(
        val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
    )
    state = evaluator.run(val_loader)
    print(state)
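As a hedged usage note: the aggregated metric from the run above can be read back from the returned Ignite state, under the name passed to attach:

# "Mean_Dice" matches the name used in MeanDice().attach(evaluator, "Mean_Dice")
print("Mean Dice:", state.metrics["Mean_Dice"])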
Example #3
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process; every GPU runs in its own process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(dataset=val_ds,
                                     even_divisible=False,
                                     shuffle=False)
    # sliding window inference requires one image per iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=True,
                            sampler=val_sampler)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    # create UNet and move it to the expected GPU device
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])
    # config mapping to expected GPU device
    map_location = {"cuda:0": f"cuda:{args.local_rank}"}
    # load model parameters to GPU device
    model.load_state_dict(
        torch.load("final_model.pth", map_location=map_location))

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define ROI size and batch size for sliding window inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]

            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if dist.get_rank() == 0:
            print("evaluation metric:", metric)
        dist.destroy_process_group()
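This function expects the environment variables and the --local_rank argument set by a distributed launcher. A minimal hedged entry point (the script name and data directory are assumptions):

# assumed launch: python -m torch.distributed.launch --nproc_per_node=NUM_GPUS unet_evaluation_ddp.py -d ./testdata
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dir", default="./testdata", type=str, help="directory of the data")
    parser.add_argument("--local_rank", type=int, help="process rank, set by torch.distributed.launch")
    args = parser.parse_args()
    evaluate(args=args)

if __name__ == "__main__":
    main()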
Example #4
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process; every GPU runs in its own process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[224, 224, 144],
                         random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster when all data is cached in memory and no per-iteration I/O is needed
    train_loader = ThreadDataLoader(train_ds,
                                    num_workers=0,
                                    batch_size=args.batch_size,
                                    shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster when all data is cached in memory and no per-iteration I/O is needed
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=args.batch_size,
                                  shuffle=False)

    # create network, loss function and optimizer
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True,
                                   reduction="mean_batch")

    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer,
                           lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, dice_metric, dice_metric_batch, post_trans)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        print(
            f"time of epoch {epoch + 1}: {(time.time() - epoch_start):.4f}s"
        )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}")
    dist.destroy_process_group()
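The train helper called above is not shown here; a hedged sketch that matches its call signature, reusing the AMP pattern from Example #1 (it relies on the script's existing imports, and assumes ToDeviced already placed the batch on the GPU):

def train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler):
    model.train()
    epoch_loss, step = 0.0, 0
    for batch_data in train_loader:
        step += 1
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(batch_data["image"])
            loss = loss_function(outputs, batch_data["label"])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    lr_scheduler.step()  # CosineAnnealingLR is stepped once per epoch
    return epoch_loss / step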
Example #5
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        EnsureTyped(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference requires one image per iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output",
                      output_ext=".png",
                      output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(
        torch.load("best_metric_model_segmentation2d_dict.pth"))

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define ROI size and batch size for sliding window inference
            roi_size = (96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
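As a sanity check on the metric above: Dice for binary masks is 2|A∩B| / (|A| + |B|). A hedged, minimal equivalent for one prediction/label pair (the epsilon is an assumption to avoid division by zero):

import torch

def dice_score(pred: torch.Tensor, label: torch.Tensor, eps: float = 1e-8) -> float:
    # pred and label are binarized tensors of identical shape
    intersection = (pred * label).sum()
    return float(2 * intersection / (pred.sum() + label.sum() + eps))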
Example #6
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceCELoss(
    to_onehot_y=True, softmax=True, squared_pred=True, batch=True
)
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()
dice_metric = DiceMetric(
    include_background=True, reduction="mean", get_not_nans=False
)

post_pred = Compose(
    [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)]
)
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
epoch_times = []
total_start = time.time()
writer = SummaryWriter(log_dir=out_dir)

with torch.autograd.profiler.emit_nvtx():
    for epoch in range(max_epochs):
        epoch_start = time.time()
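This fragment is an excerpt from a profiling run: emit_nvtx only emits NVTX ranges, which an external profiler must capture. A hedged sketch of marking a custom range, plus an assumed Nsight Systems launch command:

# assumed capture command: nsys profile -w true -t cuda,nvtx -o fast_training python fast_training.py
import torch

torch.cuda.nvtx.range_push("epoch")  # explicit NVTX range, visible in the captured trace
# ... one epoch of work ...
torch.cuda.nvtx.range_pop()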
Example #7
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of the IXI-T1 dataset
    data_path = os.sep.join(
        [".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    images = [
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join([data_path, f]) for f in images]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    # define transforms for image
    val_transforms = Compose(
        [ScaleIntensity(),
         AddChannel(),
         Resize((96, 96, 96)),
         EnsureType()])
    # define image dataset
    val_ds = ImageDataset(image_files=images,
                          labels=labels,
                          transform=val_transforms,
                          image_only=False)
    # create DenseNet121
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.DenseNet121(spatial_dims=3,
                                          in_channels=1,
                                          out_channels=2).to(device)

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}

    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch[0], batch[1]), device, non_blocking)

    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no loss to print, so disable per-iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    prediction_saver = ClassificationSaver(
        output_dir="tempdir",
        batch_transform=lambda batch: batch[2],
        output_transform=lambda output: output[0].argmax(1),
    )
    prediction_saver.attach(evaluator)

    # the model was trained by "densenet_training_array" example
    CheckpointLoader(load_path="./runs_array/net_checkpoint_20.pt",
                     load_dict={"net": net}).attach(evaluator)

    # create a validation data loader
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    state = evaluator.run(val_loader)
    print(state)
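The ClassificationSaver above writes one row per input into a CSV file in its output_dir (the default file name predictions.csv and the exact row layout are assumptions based on typical MONAI behavior):

import csv

with open("tempdir/predictions.csv") as f:
    for row in csv.reader(f):
        print(row)  # e.g. [input filename, predicted class index]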
Example #8
def evaluate(args):
    # initialize Horovod library
    hvd.init()
    # limit the number of CPU threads used per Horovod worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False, num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(mp, "_supports_context") and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference requires one image per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # create UNet and move it to the expected GPU device
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define ROI size and batch size for sliding window inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if hvd.rank() == 0:
            print("evaluation metric:", metric)
Example #9
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 40 random image/mask pairs in a temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(
                        device), val_data["seg"].to(device)
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    val_outputs = [
                        post_trans(i) for i in decollate_batch(val_outputs)
                    ]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()

                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_segmentation3d_dict.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
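A hedged sketch of why the check_loader above needs list_data_collate: with num_samples=4, each dataset item is a list of 4 cropped dicts, which the default collate cannot stack, so list_data_collate flattens the lists first (shapes follow the transforms above):

from monai.data import list_data_collate

item0 = train_ds[0]  # a list of 4 dicts, each with "img"/"seg" of shape (1, 96, 96, 96)
item1 = train_ds[1]
batch = list_data_collate([item0, item1])
print(batch["img"].shape)  # torch.Size([8, 1, 96, 96, 96]): 2 images x 4 crops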
Example #10
def main():
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iterations
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process; every GPU runs in its own process
    dist.init_process_group(backend="nccl", init_method="env://")

    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    split = len(json_data[args.json_key]) // num_folds
    list_train = (json_data[args.json_key][:(split * fold)] +
                  json_data[args.json_key][(split * (fold + 1)):])
    list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))]

    # training data
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(args.root, list_train[_i]["image"])
        str_seg = os.path.join(args.root, list_train[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    train_files_w = train_files[:len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))

    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(args.root, list_valid[_i]["image"])
        str_seg = os.path.join(args.root, list_valid[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(data=val_files,
                                  shuffle=False,
                                  num_partitions=world_size,
                                  even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        Lambdad(
            keys=["label4crop"],
            func=lambda x: np.concatenate(
                tuple(
                    ndimage.binary_dilation((x == _k).astype(x.dtype),
                                            iterations=48).astype(x.dtype)
                    for _k in range(output_classes)
                ),
                axis=0,
            ),
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        SpatialPadd(keys=["image", "label", "label4crop"],
                    spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"],
                                label_key="label4crop",
                                num_classes=output_classes,
                                ratios=[1] * output_classes,
                                spatial_size=patch_size,
                                num_samples=num_patches_per_image),
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"],
                    range_x=0.3,
                    range_y=0.3,
                    range_z=0.3,
                    mode=["bilinear", "nearest"],
                    prob=0.2),
        RandZoomd(keys=["image", "label"],
                  min_zoom=0.8,
                  max_zoom=1.2,
                  mode=["trilinear", "nearest"],
                  align_corners=[True, None],
                  prob=0.16),
        RandGaussianSmoothd(keys=["image"],
                            sigma_x=(0.5, 1.15),
                            sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15),
                            prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"],
                    dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"])
    ])

    train_ds_a = monai.data.CacheDataset(data=train_files_a,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0,
                                     num_workers=2)

    # monai.data.Dataset can be used as an alternative when debugging or when RAM is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=1,
                                  shuffle=False)

    # DataLoader can be used as an alternative when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose(
        [EnsureType(),
         AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizer
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                lr=learning_rate * world_size,
                                momentum=0.9,
                                weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")

        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        find_unused_parameters=True)

    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(
            args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(
            log_dir=os.path.join(args.output_root, "Events"))

        with open(os.path.join(args.output_root, "accuracy_history.csv"),
                  "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        decay = 0.5 ** np.sum(
            (epoch - num_epochs_warmup) /
            (num_epochs - num_epochs_warmup) > learning_rate_milestones)
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(
                device), batch_data["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]),
                                         1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) +
                      f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(
                device), sample_a["label"].to(device)
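            # bi-level optimization: freeze the network weights below so that
            # only the architecture parameters (log_alpha_a / log_alpha_c)
            # receive gradients from this search batch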
            if world_size == 1:
                for p in model.weight_parameters():
                    p.requires_grad = False
            else:
                for p in model.module.weight_parameters():
                    p.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # losses for topology and RAM usage (their weight ramps up linearly after warmup)
            entropy_alpha_c = torch.tensor(0.).to(device)
            entropy_alpha_a = torch.tensor(0.).to(device)
            ram_cost_full = torch.tensor(0.).to(device)
            ram_cost_usage = torch.tensor(0.).to(device)
            ram_cost_loss = torch.tensor(0.).to(device)
            topology_loss = torch.tensor(0.).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1) * \
                F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape,
                                                           full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost -
                                      ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            # ramp the architecture regularization from 0 to 1 over the
            # post-warmup epochs
            combination_weights = (epoch - num_epochs_warmup) / (
                num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                         1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c)
                                                   + ram_cost_loss + 0.001 * topology_loss)

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                     1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c)
                                               + ram_cost_loss + 0.001 * topology_loss)

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19]) +
                    f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(),
                                  epoch_len * epoch + step)

        # synchronize all processes and reduce the results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        dist.all_reduce(loss_torch_arch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2,
                                     dtype=torch.float,
                                     device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals),
                                                    axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        val1 = 1.0 - torch.isnan(value[0, _c]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronize all processes and reduce the results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print(
                            "evaluation metric - class {0:d}:".format(_c + 1),
                            metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = \
                        dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root,
                                     "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(
                        best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"),
                              "w") as out_file:
                        yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, avg_metric, best_metric,
                                best_metric_epoch))

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(
                            os.path.join(args.output_root,
                                         "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n"
                            .format(epoch + 1, avg_metric, loss_torch_epoch,
                                    lr, elapsed_time, idx_iter))

                dist.barrier()

            torch.cuda.empty_cache()

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return
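The milestone schedule above halves the learning rate every time the post-warmup training progress passes an entry of learning_rate_milestones. A minimal standalone sketch of that decay rule, with assumed example values for the epoch counts and milestones:

import numpy as np

# assumed example values, for illustration only
learning_rate = 0.025
num_epochs, num_epochs_warmup = 100, 10
learning_rate_milestones = np.array([0.4, 0.8])

for epoch in (10, 50, 90):
    progress = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)
    decay = 0.5 ** np.sum(progress > learning_rate_milestones)
    print(epoch, learning_rate * decay)  # 0.025, then 0.0125, then 0.00625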
Example #11
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of ixi IXI-T1 dataset
    data_path = os.sep.join([".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    images = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join([data_path, f]) for f in images]

    # binary labels for gender classification: male and female
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

    # define transforms
    train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), EnsureType()])
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])

    # define image dataset, data loader
    check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs_array/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
    )

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler(output_transform=lambda x: x)
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = ImageDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = ImageDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)
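A note on the early stopping above: Ignite's EarlyStopping calls its score_function after every evaluator run and stops training once the score has not improved for `patience` evaluations. A minimal sketch of a callable equivalent to monai.handlers.stopping_fn_from_metric("Accuracy") (higher score = better):

def stopping_fn(engine):
    # read the tracked metric from the evaluator's state; EarlyStopping resets
    # its patience counter whenever this value improves
    return engine.state.metrics["Accuracy"]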
Example #12
def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus
    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPUs:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # extend the training list to increase the number of samples per epoch
    # (each iteration doubles the list)
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # create the data dictionaries for the data loaders

    train_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{'image': image_name, 'label': label_name}
                        for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{'image': image_name, 'label': label_name}
                 for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{'image': image_name, 'label': label_name}
                  for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list

    if opt.resolution is not None:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # CT HU filter
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),               # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                          # intensity normalization
            ScaleIntensityd(keys=['image']),                                              # intensity scaling
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),  # resolution

            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1,),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            # note: np.random.uniform is evaluated once at transform construction,
            # so these noise / shift parameters stay fixed for the whole run
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 15)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than patch
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),                   # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                      # intensity
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),  # resolution

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than patch
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),               # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                          # intensity normalization
            ScaleIntensityd(keys=['image']),                                              # intensity scaling

            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1,),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            # note: np.random.uniform is evaluated once at transform construction,
            # so these noise / shift parameters stay fixed for the whole run
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than patch
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),                   # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                      # intensity
            ScaleIntensityd(keys=['image']),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than patch
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False)

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # create a test data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # build the network
    if opt.network == 'nnunet':
        net = build_net()  # nnU-Net
    elif opt.network == 'unetr':
        net = build_UNETR()  # UNETR
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark


    if opt.network == 'nnunet':

        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True,)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)

    elif opt.network == 'unetr':

        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    writer = SummaryWriter()
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if opt.network == 'nnunet':
            update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):

                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(), data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()

                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # log the epoch summary
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
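For the nnunet branch above, LambdaLR applies nnU-Net's polynomial decay (1 - epoch / max_epochs) ** 0.9 to the base learning rate. A standalone sketch with a stand-in network and an assumed epoch count:

import torch

net = torch.nn.Linear(2, 2)  # stand-in network
optim = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.99,
                        weight_decay=3e-5, nesterov=True)
epochs = 100  # assumed value of opt.epochs
net_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda epoch: (1 - epoch / epochs) ** 0.9)
for _ in range(3):
    optim.step()
    net_scheduler.step()
print(optim.param_groups[0]["lr"])  # 0.01 * (1 - 3 / 100) ** 0.9 ~= 0.00973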
Example #13
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference requires one image per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
    # try to use all the available GPUs
    devices = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]
    # devices = get_devices_spec(None)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    model.load_state_dict(
        torch.load("best_metric_model_segmentation3d_dict.pth", map_location=devices[0]))

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data["img_meta_dict"])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
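The post_trans chain used above converts raw network logits into a binary mask before the Dice computation. A minimal sketch on a dummy single-channel tensor:

import torch
from monai.transforms import Activations, AsDiscrete, Compose, EnsureType

post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
logits = torch.tensor([[[-2.0, -0.5], [1.5, 3.0]]])  # (channel, H, W)
print(post_trans(logits))  # sigmoid then threshold -> [[[0., 0.], [1., 1.]]]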
Example #14
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 40 random image/mask pairs in the given temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    train_imtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        EnsureType(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        EnsureType(),
    ])
    val_imtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        EnsureType()
    ])
    val_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        EnsureType()
    ])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(10):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    val_outputs = [
                        post_trans(i) for i in decollate_batch(val_outputs)
                    ]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
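DiceMetric above buffers one result per validation iteration and reduces them all in aggregate(). A minimal sketch of the accumulate / aggregate / reset cycle on synthetic binary masks:

import torch
from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
y_pred = torch.tensor([[[[1.0, 1.0], [0.0, 0.0]]]])  # (batch, channel, H, W)
y = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])
dice_metric(y_pred=y_pred, y=y)        # buffer this iteration's result
print(dice_metric.aggregate().item())  # 2 * 1 / (2 + 1) = 0.6667
dice_metric.reset()                    # clear the buffers for the next round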
Example #15
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 40 random image/mask pairs in the given temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(
            keys=["img", "seg"],
            label_key="seg",
            spatial_size=[96, 96, 96],
            pos=1,
            neg=1,
            num_samples=4,
        ),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(
        check_ds,
        batch_size=2,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(
        val_ds,
        batch_size=5,
        num_workers=8,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)

    # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["seg"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs_dict/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={
            "net": net,
            "opt": opt
        },
    )

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer",
                                       output_transform=lambda x: x)
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: x)
    train_tensorboard_stats_handler.attach(trainer)

    validation_every_n_iters = 5
    # set parameters for validation
    metric_name = "Mean_Dice"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice()}

    post_pred = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

    # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(
        net,
        val_metrics,
        device,
        True,
        output_transform=lambda x, y, y_pred:
        ([post_pred(i) for i in decollate_batch(y_pred)],
         [post_label(i) for i in decollate_batch(y)]),
        prepare_batch=prepare_batch,
    )

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.iteration,
    )  # fetch global iteration number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch["img"], batch["seg"]),
        output_transform=lambda output: output[0],
        global_iter_transform=lambda x: trainer.state.epoch,
    )
    evaluator.add_event_handler(
        event_name=Events.ITERATION_COMPLETED(every=2),
        handler=val_tensorboard_image_handler,
    )

    train_epochs = 5
    state = trainer.run(train_loader, train_epochs)
    print(state)
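As the loader comments above note, RandCropByPosNegLabeld with num_samples=4 expands every loaded dictionary into four patch samples, so batch_size=2 plus list_data_collate yields 8 patches per training step. A minimal sketch on a downsized synthetic volume:

import numpy as np
from monai.transforms import RandCropByPosNegLabeld

data = {
    "img": np.random.rand(1, 32, 32, 32).astype(np.float32),
    "seg": np.zeros((1, 32, 32, 32), dtype=np.float32),
}
data["seg"][0, 8:16, 8:16, 8:16] = 1.0  # a cube of foreground voxels
crop = RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg",
                              spatial_size=[16, 16, 16], pos=1, neg=1,
                              num_samples=4)
samples = crop(data)
print(len(samples), samples[0]["img"].shape)  # 4 patches of shape (1, 16, 16, 16)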
Example #16
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of ixi IXI-T1 dataset
    data_path = os.sep.join(
        [".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    images = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join([data_path, f]) for f in images]

    # binary labels for gender classification: male and female
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)

    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        Resize((96, 96, 96)),
        RandRotate90(),
        EnsureType()
    ])
    val_transforms = Compose(
        [ScaleIntensity(),
         AddChannel(),
         Resize((96, 96, 96)),
         EnsureType()])

    # Define image dataset, data loader
    check_ds = ImageDataset(image_files=images,
                            labels=labels,
                            transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create a training data loader
    train_ds = ImageDataset(image_files=images[:10],
                            labels=labels[:10],
                            transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = ImageDataset(image_files=images[-10:],
                          labels=labels[-10:],
                          transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=2,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3,
                                            in_channels=1,
                                            out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.0
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_classification3d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_accuracy", metric, epoch + 1)
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
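The validation loop above scores accuracy by comparing argmax predictions with the labels. A minimal sketch on dummy logits:

import torch

val_outputs = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.5]])  # (batch, classes)
val_labels = torch.tensor([0, 1, 1])
value = torch.eq(val_outputs.argmax(dim=1), val_labels)
print(value.sum().item() / len(value))  # 2 of 3 correct -> 0.6667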
Example #17
def segment(image, label, result, weights, resolution, patch_size, network,
            gpu_ids):

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
    else:
        files = [{"image": image}]

    # original size, size after CropForeground, ROI start/end coordinates,
    # resampled ROI size, and the original image resolution
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(
        image, resolution)

    # -------------------------------

    # the four original branches (with/without label, with/without target
    # resolution) differ only in the transform keys and the optional Spacingd,
    # so the transform chain is built once
    keys = ['image', 'label'] if label is not None else ['image']
    interp = ('bilinear', 'nearest') if label is not None else 'bilinear'

    transform_list = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=keys, source_key='image'),  # crop the background
        NormalizeIntensityd(keys=['image']),  # intensity
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        transform_list.append(
            Spacingd(keys=keys, pixdim=resolution, mode=interp))  # resolution
    transform_list += [
        SpatialPadd(keys=keys, spatial_size=patch_size,
                    method='end'),  # pad if the image is smaller than patch
        ToTensord(keys=keys),
    ]
    val_transforms = Compose(transform_list)

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=0,
                            collate_fn=list_data_collate,
                            pin_memory=False)

    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose([
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True)
    ])
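    # note: AsDiscrete(threshold_values=True) above is the older (pre-0.9)
    # MONAI spelling; newer releases write the same binarization as
    # AsDiscrete(threshold=0.5), as in the later examples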

    if gpu_ids != '-1':

        # try to use all the available GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    else:
        device = torch.device("cpu")

    # build the network
    if network == 'nnunet':
        net = build_net()  # nnU-Net
    elif network == 'unetr':
        net = build_UNETR()  # UNETR
    else:
        raise ValueError(f"unsupported network: {network!r}")

    net = net.to(device)

    if gpu_ids == '-1':

        net.load_state_dict(new_state_dict_cpu(weights))

    else:

        net.load_state_dict(new_state_dict(weights))

    # define sliding window size and batch size for windows inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():

        if label is None:
            for val_data in val_loader:
                val_images = val_data["image"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                val_outputs = [
                    post_trans(i) for i in decollate_batch(val_outputs)
                ]

        else:
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(
                    device), val_data["label"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                val_outputs = [
                    post_trans(i) for i in decollate_batch(val_outputs)
                ]
                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs[0].squeeze().data.cpu().numpy()
        # Remove the pad if the image was smaller than the patch in some directions
        result_array = result_array[0:resampled_size[0], 0:resampled_size[1],
                                    0:resampled_size[2]]

        # resample back to the original resolution
        if resolution is not None:

            result_array_np = np.transpose(result_array, (2, 1, 0))
            result_array_temp = sitk.GetImageFromArray(result_array_np)
            result_array_temp.SetSpacing(resolution)

            # save temporary label
            writer = sitk.ImageFileWriter()
            writer.SetFileName('temp_seg.nii')
            writer.Execute(result_array_temp)

            files = [{"image": 'temp_seg.nii'}]

            files_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                Spacingd(keys=['image'],
                         pixdim=original_resolution,
                         mode='nearest'),
                Resized(keys=['image'],
                        spatial_size=crop_shape,
                        mode='nearest'),
            ])

            files_ds = Dataset(data=files, transform=files_transforms)
            files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

            for files_data in files_loader:
                files_images = files_data["image"]

                res = files_images.squeeze().data.numpy()

            result_array = np.rint(res)

            os.remove('./temp_seg.nii')

        # recover the cropped background before saving the image
        empty_array = np.zeros(original_shape)
        empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1],
                    coord1[2]:coord2[2]] = result_array

        result_seg = from_numpy_to_itk(empty_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(result_seg)
        print("Saved Result at:", str(result))
Example #18
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
    segtrans = Compose([AddChannel(), EnsureType()])
    val_ds = ImageDataset(images,
                          segs,
                          transform=imtrans,
                          seg_transform=segtrans,
                          image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=1,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output",
                      output_ext=".nii.gz",
                      output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(
        torch.load("best_metric_model_segmentation3d_array.pth",
                   map_location=device))
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(
                device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            meta_data = decollate_batch(val_data[2])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
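
main(tempdir) only needs a writable scratch directory for the synthetic volumes; a minimal driver, assuming the standard-library tempfile module, would be:

if __name__ == "__main__":
    import tempfile

    # create a throwaway directory for the generated NIfTI files
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)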
Example #19
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # the path of the IXI-T1 dataset
    data_path = os.sep.join([".", "workspace", "data", "medical", "ixi", "IXI-T1"])
    images = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join([data_path, f]) for f in images]

    # two classes for binary gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
    val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]

    # Define transforms for image
    train_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
            EnsureTyped(keys=["img"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            ScaleIntensityd(keys=["img"]),
            Resized(keys=["img"], spatial_size=(96, 96, 96)),
            EnsureTyped(keys=["img"]),
        ]
    )
    post_pred = Compose([EnsureType(), Activations(softmax=True)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    # Define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    auc_metric = ROCAUCMetric()

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["label"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["label"].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)

                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                y_onehot = [post_label(i) for i in decollate_batch(y)]
                y_pred_act = [post_pred(i) for i in decollate_batch(y_pred)]
                auc_metric(y_pred_act, y_onehot)
                auc_result = auc_metric.aggregate()
                auc_metric.reset()
                del y_pred_act, y_onehot
                if acc_metric > best_metric:
                    best_metric = acc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_classification3d_dict.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                        epoch + 1, acc_metric, auc_result, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
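
The validation block above interleaves accuracy with the decollate/post-transform AUC pattern; the standalone sketch below replays that pattern on toy logits so the data flow is visible in isolation (shapes and values are illustrative only).

import torch
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.transforms import Activations, AsDiscrete, Compose, EnsureType

post_pred = Compose([EnsureType(), Activations(softmax=True)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
auc_metric = ROCAUCMetric()

y_pred = torch.randn(4, 2)        # toy logits: 4 samples, 2 classes
y = torch.randint(0, 2, (4,))     # toy integer labels

# split the batch into per-sample tensors, then post-process each one
y_pred_act = [post_pred(i) for i in decollate_batch(y_pred)]
y_onehot = [post_label(i) for i in decollate_batch(y)]
auc_metric(y_pred_act, y_onehot)
print("AUC:", auc_metric.aggregate())
auc_metric.reset()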