Example #1
    def test_value_class(self, input_data, expected_value):

        # same test as for compute_meandice
        vals = dict()
        vals["y_pred"] = input_data.pop("y_pred")
        vals["y"] = input_data.pop("y")
        dice_metric = DiceMetric(**input_data, reduction="none")
        result, _ = dice_metric(**vals)
        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
Example #2
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth"))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define the ROI size and batch size for sliding window inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
Example #3
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs one image per iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define the ROI size and batch size for sliding window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            not_nans = dice_metric.not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = metric_sum / metric_count
    return metric
Example #4
def evaluate(model, data_loader, device):
    metric = torch.zeros(8, dtype=torch.float, device=device)

    model.eval()
    with torch.no_grad():
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        post_trans = Compose(
            [Activations(sigmoid=True),
             AsDiscrete(threshold_values=True)])
        for val_data in data_loader:
            val_inputs, val_labels = (
                val_data["image"].to(device, non_blocking=True),
                val_data["label"].to(device, non_blocking=True),
            )
            val_outputs = model(val_inputs)
            val_outputs = post_trans(val_outputs)
            # compute overall mean dice
            value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
            value = value.squeeze()
            metric[0] += value * not_nans
            metric[1] += not_nans
            # compute mean dice for TC
            value_tc, not_nans = dice_metric(y_pred=val_outputs[:, 0:1],
                                             y=val_labels[:, 0:1])
            value_tc = value_tc.squeeze()
            metric[2] += value_tc * not_nans
            metric[3] += not_nans
            # compute mean dice for WT
            value_wt, not_nans = dice_metric(y_pred=val_outputs[:, 1:2],
                                             y=val_labels[:, 1:2])
            value_wt = value_wt.squeeze()
            metric[4] += value_wt * not_nans
            metric[5] += not_nans
            # compute mean dice for ET
            value_et, not_nans = dice_metric(y_pred=val_outputs[:, 2:3],
                                             y=val_labels[:, 2:3])
            value_et = value_et.squeeze()
            metric[6] += value_et * not_nans
            metric[7] += not_nans

        # synchronize all processes and reduce the results
        dist.barrier()
        dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
        metric = metric.tolist()

    return metric[0] / metric[1], metric[2] / metric[3], metric[4] / metric[
        5], metric[6] / metric[7]
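The evaluate() function above assumes that torch.distributed has already been initialized by the launcher. A minimal sketch of that setup, assuming one process per GPU started with torchrun (the NCCL backend and the LOCAL_RANK environment variable come from that assumption, not from the snippet above):

import os

import torch
import torch.distributed as dist

# join the default process group; torchrun supplies MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE
dist.init_process_group(backend="nccl", init_method="env://")
# one process per GPU: bind this process to the device given by its local rank
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

# mean_dice, dice_tc, dice_wt, dice_et = evaluate(model, data_loader, device)  # model/data_loader defined elsewhere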
Example #5
def evaluate(model, data_loader, device):
    metric_sum = metric_sum_tc = metric_sum_wt = metric_sum_et = 0.0
    metric_count = metric_count_tc = metric_count_wt = metric_count_et = 0

    model.eval()
    with torch.no_grad():
        dice_metric = DiceMetric(include_background=True,
                                 sigmoid=True,
                                 reduction="mean")
        for val_data in data_loader:
            val_inputs, val_labels = (
                val_data["image"].to(device, non_blocking=True),
                val_data["label"].to(device, non_blocking=True),
            )
            val_outputs = model(val_inputs)
            # compute overall mean dice
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            not_nans = dice_metric.not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
            # compute mean dice for TC
            value_tc = dice_metric(y_pred=val_outputs[:, 0:1],
                                   y=val_labels[:, 0:1])
            not_nans = dice_metric.not_nans.item()
            metric_count_tc += not_nans
            metric_sum_tc += value_tc.item() * not_nans
            # compute mean dice for WT
            value_wt = dice_metric(y_pred=val_outputs[:, 1:2],
                                   y=val_labels[:, 1:2])
            not_nans = dice_metric.not_nans.item()
            metric_count_wt += not_nans
            metric_sum_wt += value_wt.item() * not_nans
            # compute mean dice for ET
            value_et = dice_metric(y_pred=val_outputs[:, 2:3],
                                   y=val_labels[:, 2:3])
            not_nans = dice_metric.not_nans.item()
            metric_count_et += not_nans
            metric_sum_et += value_et.item() * not_nans

        metric = metric_sum / metric_count
        metric_tc = metric_sum_tc / metric_count_tc
        metric_wt = metric_sum_wt / metric_count_wt
        metric_et = metric_sum_et / metric_count_et

    return metric, metric_tc, metric_wt, metric_et
Example #6
class MonaiMetrics(tools.Metrics):
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    def score(self, y_true, y_pred):
        metric_sum = 0.0
        metric_count = 0
        with torch.no_grad():
            for (val_true, val_pred) in zip(y_true, y_pred):
                val_true, _ = val_true
                val_pred, _ = val_pred
                value = self.dice_metric(
                    y_pred=val_pred,
                    y=val_true,
                )
                metric_count += len(value)
                metric_sum += value.item() * len(value)
        metric = metric_sum / metric_count
        return metric
Example #7
    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        mutually_exclusive: bool = False,
        sigmoid: bool = False,
        other_act: Optional[Callable] = None,
        logit_thresh: float = 0.5,
        output_transform: Callable = lambda x: x,
        device: Optional[torch.device] = None,
    ) -> None:
        """

        Args:
            include_background: whether to include dice computation on the first channel of the predicted output.
                Defaults to True.
            to_onehot_y: whether to convert the output prediction into the one-hot format. Defaults to False.
            mutually_exclusive: if True, the output prediction will be converted into a binary matrix using
                a combination of argmax and to_onehot. Defaults to False.
            sigmoid: whether to add sigmoid function to the output prediction before computing Dice.
                Defaults to False.
            other_act: callable function to replace `sigmoid` as the activation layer if needed. Defaults to ``None``,
                for example: `other_act = torch.tanh`.
            logit_thresh: the threshold value used to round values to 0.0 and 1.0. Defaults to 0.5.
            output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
            device: device specification in case of distributed computation usage.

        See also:
            :py:meth:`monai.metrics.meandice.compute_meandice`
        """
        super().__init__(output_transform, device=device)
        self.dice = DiceMetric(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            mutually_exclusive=mutually_exclusive,
            sigmoid=sigmoid,
            other_act=other_act,
            logit_thresh=logit_thresh,
            reduction=MetricReduction.MEAN,
        )
        self._sum = 0.0
        self._num_examples = 0
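The constructor above wraps DiceMetric for an ignite-style handler; the arguments mirror the legacy metric API described in the docstring. For comparison, a minimal sketch of calling DiceMetric directly, assuming a more recent MONAI release (0.6+) where activation and thresholding are applied beforehand and results are buffered until aggregate() is called; the tensors below are synthetic placeholders:

import torch
from monai.metrics import DiceMetric

# synthetic, already-binarized predictions and labels with shape (batch, channel, *spatial)
y_pred = (torch.rand(2, 1, 32, 32, 32) > 0.5).float()
y = (torch.rand(2, 1, 32, 32, 32) > 0.5).float()

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric(y_pred=y_pred, y=y)        # buffers the per-batch result
print(dice_metric.aggregate().item())  # mean Dice over everything buffered so far
dice_metric.reset()                    # clear the buffer before the next evaluation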
Example #8
    def __init__(
        self,
        include_background: bool = True,
        output_transform: Callable = lambda x: x,
        device: Optional[torch.device] = None,
    ) -> None:
        """

        Args:
            include_background: whether to include dice computation on the first channel of the predicted output.
                Defaults to True.
            output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
            device: device specification in case of distributed computation usage.

        See also:
            :py:meth:`monai.metrics.meandice.compute_meandice`
        """
        super().__init__(output_transform, device=device)
        self.dice = DiceMetric(
            include_background=include_background,
            reduction=MetricReduction.MEAN,
        )
        self._sum = 0.0
        self._num_examples = 0
Example #9
def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPUs:', num_gpus)

    # Data loader creation

    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    for i in range(int(opt.increase_factor_data)):

        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data dictionaries for the data loaders

    train_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{'image': image_name, 'label': label_name}
                        for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{'image': image_name, 'label': label_name}
                 for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{'image': image_name, 'label': label_name}
                  for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list
    # Need to concatenate multiple channels here if you want multichannel segmentation
    # (see the commented ConcatItemsd sketch below); check other examples on the MONAI website.
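    # A hedged sketch of the multichannel case mentioned above (the 't1'/'t2' keys are
    # hypothetical): modalities can be stacked into a single 'image' key with
    # monai.transforms.ConcatItemsd before the intensity transforms, e.g.
    #   LoadImaged(keys=['t1', 't2', 'label']),
    #   AddChanneld(keys=['t1', 't2', 'label']),
    #   ConcatItemsd(keys=['t1', 't2'], name='image', dim=0),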

    if opt.resolution is not None:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1),
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15), padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'],  gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1),
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a test data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # # try to use all the available GPUs
    # devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    # loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)
    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):

                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(), data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        val_outputs = post_trans(val_outputs)
                        value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                        metric_count += len(value)
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # log the epoch summary
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # (val_outputs returned by plot_dice is already binarized by post_trans)
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
Example #10
    def _define_validation_metric(self):
        """Define the metric for validation function."""

        self._validation_metric = DiceMetric(include_background=True,
                                             reduction="mean")
Example #11
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceCELoss(
    to_onehot_y=True, softmax=True, squared_pred=True, batch=True
)
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()
dice_metric = DiceMetric(
    include_background=True, reduction="mean", get_not_nans=False
)

post_pred = Compose(
    [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)]
)
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
epoch_times = []
total_start = time.time()
writer = SummaryWriter(log_dir=out_dir)
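The objects built above (model, loss_function, optimizer, scaler) are typically driven by a mixed-precision training step like the sketch below; one epoch is shown, and train_loader plus device are assumed to be defined elsewhere in the script:

for batch_data in train_loader:
    inputs = batch_data["image"].to(device)
    labels = batch_data["label"].to(device)
    optimizer.zero_grad()
    # run the forward pass and the loss under autocast (mixed float16/float32)
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
    # scale the loss to avoid fp16 underflow; step() unscales before the update
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()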
Example #12
def evaluate(args):
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])

    # create an evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds,
                                     num_replicas=hvd.size(),
                                     rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(
            mp, "_supports_context"
    ) and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference needs one image per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{hvd.local_rank()}")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    model.eval()
    with torch.no_grad():
        # define PyTorch Tensor to record metrics result at each GPU
        # the first value is the `sum` of the dice metric, the second value is the `count` of not-NaN items
        metric = torch.zeros(2, dtype=torch.float, device=device)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define the ROI size and batch size for sliding window inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            metric[0] += value * dice_metric.not_nans
            metric[1] += dice_metric.not_nans
        # synchronize all processes and reduce the results
        print(
            f"metric in rank {hvd.rank()}: sum={metric[0].item()}, count={metric[1].item()}"
        )
        avg_metric = hvd.allreduce(metric, name="mean_dice")
        if hvd.rank() == 0:
            print(
                f"average metric: sum={avg_metric[0].item()}, count={avg_metric[1].item()}"
            )
            print("evaluation metric:", (avg_metric[0] / avg_metric[1]).item())
Example #13
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs one image per iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    # try to use all the available GPUs
    devices = get_devices_spec(None)
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(devices[0])

    model.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))

    # if we have multiple GPUs, set data parallel to execute sliding window inference
    if len(devices) > 1:
        model = torch.nn.DataParallel(model, device_ids=devices)

    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
            # define the ROI size and batch size for sliding window inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
Example #14
# create UNet, DiceLoss and Adam optimizer
# device = torch.device("cuda:0")  # you don't need device, because Catalyst uses autoscaling
model = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

dice_metric = DiceMetric(include_background=True,
                         to_onehot_y=False,
                         sigmoid=True,
                         reduction="mean")


class MonaiSupervisedRunner(dl.SupervisedRunner):
    def forward(self, batch):
        if self.is_train_loader:
            output = {self.output_key: self.model(batch[self.input_key])}
        elif self.is_valid_loader:
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            output = {
                self.output_key:
                sliding_window_inference(batch[self.input_key], roi_size,
                                         sw_batch_size, self.model)
            }
Example #15
def segment(image, label, result, weights, resolution, patch_size, network,
            gpu_ids):

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
    else:
        files = [{"image": image}]

    # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(
        image, resolution)

    # -------------------------------

    if label is not None:
        if resolution is not None:

            val_transforms = Compose([
                LoadImaged(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image', 'label'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image', 'label'],
                         pixdim=resolution,
                         mode=('bilinear', 'nearest')),  # resolution
                SpatialPadd(keys=['image', 'label'],
                            spatial_size=patch_size,
                            method='end'),
                ToTensord(keys=['image', 'label'])
            ])
        else:

            val_transforms = Compose([
                LoadImaged(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image', 'label'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                SpatialPadd(
                    keys=['image', 'label'],
                    spatial_size=patch_size,
                    method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image', 'label'])
            ])

    else:
        if resolution is not None:

            val_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image'], pixdim=resolution,
                         mode=('bilinear')),  # resolution
                SpatialPadd(
                    keys=['image'], spatial_size=patch_size,
                    method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image'])
            ])
        else:

            val_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                SpatialPadd(
                    keys=['image'], spatial_size=patch_size,
                    method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image'])
            ])

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=0,
                            collate_fn=list_data_collate,
                            pin_memory=False)

    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose([
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True)
    ])

    if gpu_ids != '-1':

        # try to use all the available GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    else:
        device = torch.device("cpu")

    # build the network
    if network == 'nnunet':
        net = build_net()  # nn build_net
    elif network == 'unetr':
        net = build_UNETR()  # UneTR

    net = net.to(device)

    if gpu_ids == '-1':

        net.load_state_dict(new_state_dict_cpu(weights))

    else:

        net.load_state_dict(new_state_dict(weights))

    # define the ROI size and batch size for sliding window inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():

        if label is None:
            for val_data in val_loader:
                val_images = val_data["image"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                val_outputs = [
                    post_trans(i) for i in decollate_batch(val_outputs)
                ]

        else:
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(
                    device), val_data["label"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                val_outputs = [
                    post_trans(i) for i in decollate_batch(val_outputs)
                ]
                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs[0].squeeze().data.cpu().numpy()
        # Remove the pad if the image was smaller than the patch in some directions
        result_array = result_array[0:resampled_size[0], 0:resampled_size[1],
                                    0:resampled_size[2]]

        # resample back to the original resolution
        if resolution is not None:

            result_array_np = np.transpose(result_array, (2, 1, 0))
            result_array_temp = sitk.GetImageFromArray(result_array_np)
            result_array_temp.SetSpacing(resolution)

            # save temporary label
            writer = sitk.ImageFileWriter()
            writer.SetFileName('temp_seg.nii')
            writer.Execute(result_array_temp)

            files = [{"image": 'temp_seg.nii'}]

            files_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                Spacingd(keys=['image'],
                         pixdim=original_resolution,
                         mode=('nearest')),
                Resized(keys=['image'],
                        spatial_size=crop_shape,
                        mode=('nearest')),
            ])

            files_ds = Dataset(data=files, transform=files_transforms)
            files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

            for files_data in files_loader:
                files_images = files_data["image"]

                res = files_images.squeeze().data.numpy()

            result_array = np.rint(res)

            os.remove('./temp_seg.nii')

        # recover the cropped background before saving the image
        empty_array = np.zeros(original_shape)
        empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1],
                    coord1[2]:coord2[2]] = result_array

        result_seg = from_numpy_to_itk(empty_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(result_seg)
        print("Saved Result at:", str(result))
Example #16
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model))
                    value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += not_nans.item()
                    metric_sum += value.item() * not_nans.item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
Example #17
def segment(image, label, result, weights, resolution, patch_size, gpu_ids):

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        files = [{"image": image, "label": label}]
    else:
        files = [{"image": image}]

    if label is not None:
        if resolution is not None:
            val_transforms = Compose([
                LoadNiftid(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                NormalizeIntensityd(keys=['image']),
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image', 'label'],
                         pixdim=resolution,
                         mode=('bilinear', 'nearest')),
                ToTensord(keys=['image', 'label'])
            ])
        else:
            val_transforms = Compose([
                LoadNiftid(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                NormalizeIntensityd(keys=['image']),
                ScaleIntensityd(keys=['image']),
                ToTensord(keys=['image', 'label'])
            ])

    else:
        if resolution is not None:
            val_transforms = Compose([
                LoadNiftid(keys=['image']),
                AddChanneld(keys=['image']),
                NormalizeIntensityd(keys=['image']),
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image'], pixdim=resolution, mode='bilinear'),
                ToTensord(keys=['image'])
            ])
        else:
            val_transforms = Compose([
                LoadNiftid(keys=['image']),
                AddChanneld(keys=['image']),
                NormalizeIntensityd(keys=['image']),
                ScaleIntensityd(keys=['image']),
                ToTensord(keys=['image'])
            ])

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # try to use all the available GPUs
    devices = get_devices_spec(None)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids  # multi-GPU selector

    if len(devices) > 1:

        # build the network
        net = build_net()
        net = net.to(devices[0])

        net = torch.nn.DataParallel(net, device_ids=devices)
        net.load_state_dict(torch.load(weights))

    else:

        net = build_net().cuda()
        net.load_state_dict(new_state_dict(weights))

    # define the ROI size and batch size for sliding window inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():

        if label is None:
            for val_data in val_loader:
                val_images = val_data["image"].cuda()
                # run sliding window inference over the image
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()

        else:
            metric_sum = 0.0
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].cuda(
                ), val_data["label"].cuda()
                # run sliding window inference over the image
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()

            metric = metric_sum / metric_count
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs.squeeze().data.cpu().numpy()

        if resolution is not None:

            reader = sitk.ImageFileReader()
            reader.SetFileName(image)
            image_itk = reader.Execute()

            res = sitk.GetImageFromArray(np.transpose(result_array, (2, 1, 0)))
            res = resize(res, (sitk.GetArrayFromImage(image_itk)).shape[::-1],
                         sitk.sitkNearestNeighbor)
            res = (np.rint(sitk.GetArrayFromImage(res)))
            res = sitk.GetImageFromArray(res)
            res.SetDirection(image_itk.GetDirection())
            res.SetOrigin(image_itk.GetOrigin())
            res.SetSpacing(image_itk.GetSpacing())

        else:
            res = from_numpy_to_itk(result_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(res)
        print("Saved Result at:", str(result))

        if resolution is not None:

            def dice_coeff(seg, gt):
                """
                function to calculate the dice score
                """
                seg = seg.flatten()
                gt = gt.flatten()
                dice = float(2 *
                             (gt * seg).sum()) / float(gt.sum() + seg.sum())
                return dice

            y_pred = sitk.GetArrayFromImage(sitk.ReadImage(result))
            y_true = sitk.GetArrayFromImage(sitk.ReadImage(label))

            print("Dice after resampling to original resolution:",
                  dice_coeff(y_pred, y_true))
Example #18
def main():
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    if not os.path.exists(logdir):
        os.mkdir(logdir)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-175,
                             a_max=250,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")

    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_num=6,
                          cache_rate=1.0,
                          num_workers=4)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading pretrained ViT weights from {}'.format(pretrained_path))
        vit_dict = torch.load(pretrained_path)
        vit_weights = vit_dict['state_dict']

        #  Remove conv3d_transpose.weight, conv3d_transpose.bias,
        #  conv3d_transpose_1.weight and conv3d_transpose_1.bias: these layers were only
        #  used to match dimensions during ViTAutoEnc pretraining and are not part of the
        #  ViT backbone that UNETR uses
        vit_weights.pop('conv3d_transpose_1.bias')
        vit_weights.pop('conv3d_transpose_1.weight')
        vit_weights.pop('conv3d_transpose.bias')
        vit_weights.pop('conv3d_transpose.weight')

        model.vit.load_state_dict(vit_weights)
        print('Pretrained Weights Successfully Loaded!')

    elif use_pretrained == 0:
        print(
            'No weights were loaded, all weights being used are randomly initialized!'
        )

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
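    # enable cuDNN auto-tuning of convolution algorithms (input patches have a fixed size)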
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val):
        model.eval()
        dice_vals = list()

        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = (batch["image"].cuda(),
                                          batch["label"].cuda())
                val_outputs = sliding_window_inference(val_inputs,
                                                       (96, 96, 96), 4, model)
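                # decollate the batch and convert labels / predictions to one-hot
                # so DiceMetric can be evaluated per class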
                val_labels_list = decollate_batch(val_labels)
                val_labels_convert = [
                    post_label(val_label_tensor)
                    for val_label_tensor in val_labels_list
                ]
                val_outputs_list = decollate_batch(val_outputs)
                val_output_convert = [
                    post_pred(val_pred_tensor)
                    for val_pred_tensor in val_outputs_list
                ]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
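                # aggregate() returns the mean Dice over all cases buffered so far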
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" %
                    (global_step, max_iterations, dice))

            dice_metric.reset()

        mean_dice_val = np.mean(dice_vals)
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        model.train()
        epoch_loss = 0
        step = 0
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator):
            step += 1
            x, y = (batch["image"].cuda(), batch["label"].cuda())
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, max_iterations, loss))

            if (global_step % eval_num == 0
                    and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(
                    val_loader,
                    desc="Validate (X / X Steps) (dice=X.X)",
                    dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val)

                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(),
                               os.path.join(logdir, "best_metric_model.pth"))
                    print(
                        "Model Was Saved! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))
                else:
                    print(
                        "Model Was Not Saved! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))

                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
                y = epoch_loss_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                x = [eval_num * (i + 1) for i in range(len(metric_values))]
                y = metric_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.savefig(
                    os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)

            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)
    model.load_state_dict(
        torch.load(os.path.join(logdir, "best_metric_model.pth")))

    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
Ejemplo n.º 19
0
def main(cfg):
    """Runs main training procedure."""

    print('Starting training...')
    print('Current working directory is:', os.getcwd())

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])

    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = not cfg['ignore_channels']

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'unetplusplus':
        UnetPlusPlus(encoder_name=cfg['encoder_name'],
                     encoder_weights=cfg['encoder_weights'],
                     decoder_use_batchnorm=cfg['use_batchnorm'],
                     classes=num_classes,
                     activation=activation,
                     aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }

    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine'])
    }

    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }

    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }

    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuple of (metric, name, type)

    # configure scheduler
    schedulers = {
        'steplr':
        StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine':
        CosineAnnealingLR(optimizer,
                          cfg['epochs'],
                          eta_min=cfg['eta_min'],
                          last_epoch=-1)
    }

    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size'])
    valid_transform = load_valid_transform(patch_size=cfg['patch_size'])

    train_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                        transform=train_transform,
                                        normalize=cfg['normalize'],
                                        tissuemix=cfg['tissuemix'],
                                        probability=cfg['probability'],
                                        blending=cfg['blending'],
                                        warping=cfg['warping'],
                                        color=cfg['color'])

    valid_dataset = SegmentationDataset(df_path=cfg['valid_data'],
                                        transform=valid_transform,
                                        normalize=cfg['normalize'])

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                              transform=None,
                                              normalize=None)

        transform_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                                transform=None,
                                                normalize=None,
                                                tissuemix=cfg['tissuemix'],
                                                probability=cfg['probability'],
                                                blending=cfg['blending'],
                                                warping=cfg['warping'],
                                                color=cfg['color'])

        for idx in range(0, min(500, len(default_dataset)), 10):
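            # save every 10th of the first 500 samples as PNGs so the effect of the
            # augmentation pipeline can be inspected visually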
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0))
            image_input = image_input.astype(np.uint8)

            image_mask = image_mask.transpose(
                1, 2, 0)  # channel-first (C, H, W) -> channel-last (H, W, C) for saving
            image_mask = image_mask.astype(np.uint8)
            image_mask = image_mask.squeeze()
            image_mask = image_mask * 255

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=1)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)

        del transform_dataset

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])

    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)

    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'])

    # validation inference
    best_model = model
    best_model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    best_model.to(cfg['device'])
    best_model.eval()

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir']):
        # remove existing dir and content
        shutil.rmtree(cfg['plot_dir'])
    # create absolute destination
    os.makedirs(cfg['plot_dir'])

    # valid dataset without transformations and normalization for image visualization
    valid_dataset_vis = SegmentationDataset(df_path=cfg['valid_data'],
                                            transform=valid_transform,
                                            normalize=None)

    if cfg['save_checkpoints']:
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose((1, 2, 0))

            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose((1, 2, 0))
            gt_mask = gt_mask.squeeze()

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
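            # predict() returns (segmentation mask, auxiliary classification output);
            # only the mask is used here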
            pr_mask, _ = best_model.predict(x_tensor)
            pr_mask = pr_mask.cpu().numpy().round()
            pr_mask = pr_mask.squeeze()

            save_predictions(out_path=cfg['plot_dir'],
                             index=n,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask,
                             average='macro')
Ejemplo n.º 20
0
    def test_nans_class(self, params, input_data, expected_value):

        dice_metric = DiceMetric(**params)
        dice_metric(**input_data)
        result, _ = dice_metric.aggregate()
        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
Ejemplo n.º 21
0
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[224, 224, 144],
                         random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster than DataLoader when all data is cached in memory and no per-iteration I/O is needed
    train_loader = ThreadDataLoader(train_ds,
                                    num_workers=0,
                                    batch_size=args.batch_size,
                                    shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster than DataLoader when all data is cached in memory and no per-iteration I/O is needed
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=args.batch_size,
                                  shuffle=False)

    # create network, loss function and optimizer
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True,
                                   reduction="mean_batch")

    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer,
                           lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, dice_metric, dice_metric_batch, post_trans)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                if dist.get_rank() == 0:
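                    # only rank 0 saves the checkpoint, to avoid concurrent writes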
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        print(
            f"time consumed by epoch {epoch + 1}: {(time.time() - epoch_start):.4f}"
        )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}")
    dist.destroy_process_group()
Ejemplo n.º 22
0
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
    segtrans = Compose([AddChannel(), EnsureType()])
    val_ds = ImageDataset(images,
                          segs,
                          transform=imtrans,
                          seg_transform=segtrans,
                          image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=1,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output",
                      output_ext=".nii.gz",
                      output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(
        torch.load("best_metric_model_segmentation3d_array.pth"))
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(
                device)
            # define the sliding window ROI size and batch size for inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
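            # val_data[2] holds the image meta data (ImageDataset was created with
            # image_only=False); SaveImage uses it to restore filename and spatial metadata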
            meta_data = decollate_batch(val_data[2])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output, data in zip(val_outputs, meta_data):
                saver(val_output, data)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
Ejemplo n.º 23
0
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds)
    # sliding window inference needs one image per iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            pin_memory=True,
                            sampler=val_sampler)
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # create UNet and move it to the GPU for this process
    device = torch.device(f"cuda:{args.local_rank}")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    # config mapping to expected GPU device
    map_location = {"cuda:0": f"cuda:{args.local_rank}"}
    # load model parameters to GPU device
    model.load_state_dict(
        torch.load("final_model.pth", map_location=map_location))

    model.eval()
    with torch.no_grad():
        # define PyTorch Tensor to record metrics result at each GPU
        # the first value is `sum` of all dice metric, the second value is `count` of not_nan items
        metric = torch.zeros(2, dtype=torch.float, device=device)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define the sliding window ROI size and batch size for inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            metric[0] += value * dice_metric.not_nans
            metric[1] += dice_metric.not_nans
        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
        metric = metric.tolist()
        if dist.get_rank() == 0:
            print("evaluation metric:", metric[0] / metric[1])
        dist.destroy_process_group()
Ejemplo n.º 24
0
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler (even_divisible=False so each image is evaluated exactly once)
    val_sampler = DistributedSampler(dataset=val_ds,
                                     even_divisible=False,
                                     shuffle=False)
    # sliding window inference needs one image per iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=True,
                            sampler=val_sampler)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    # create UNet and move it to the GPU for this process
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])
    # config mapping to expected GPU device
    map_location = {"cuda:0": f"cuda:{args.local_rank}"}
    # load model parameters to GPU device
    model.load_state_dict(
        torch.load("final_model.pth", map_location=map_location))

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define the sliding window ROI size and batch size for inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]

            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if dist.get_rank() == 0:
            print("evaluation metric:", metric)
        dist.destroy_process_group()
Ejemplo n.º 25
0
def evaluate(gpu, args):
    # initialize the distributed evaluation process, every GPU runs in a process,
    # so the process rank is (node index x GPU count of 1 node + GPU index)
    rank = args.node * args.gpus + gpu
    dist.init_process_group(backend="nccl",
                            init_method="env://",
                            world_size=args.world_size,
                            rank=rank)

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds,
                                     num_replicas=args.world_size,
                                     rank=rank)
    # sliding window inference needs one image per iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
        sampler=val_sampler,
    )
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # create UNet and move it to the GPU for this process
    device = torch.device(f"cuda:{gpu}")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[gpu])
    # config mapping to expected GPU device
    map_location = {"cuda:0": f"cuda:{gpu}"}
    # load model parameters to GPU device
    model.load_state_dict(
        torch.load("final_model.pth", map_location=map_location))

    model.eval()
    with torch.no_grad():
        # define PyTorch Tensor to record metrics result at each GPU
        # the first value is `sum` of all dice metric, the second value is `count` of not_nan items
        metric = torch.zeros(2, dtype=torch.float, device=device)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define the sliding window ROI size and batch size for inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            metric[0] += value * dice_metric.not_nans
            metric[1] += dice_metric.not_nans
        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
        metric = metric.tolist()
        if rank == 0:
            print("evaluation metric:", metric[0] / metric[1])
        dist.destroy_process_group()
    def test_train_timing(self):
        images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
        train_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[:32], segs[:32])]
        val_files = [{
            "image": img,
            "label": seg
        } for img, seg in zip(images[-9:], segs[-9:])]

        device = torch.device("cuda:0")
        # define transforms for train and validation
        train_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"],
                          prob=0.5,
                          spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"],
                      prob=0.5,
                      min_zoom=0.8,
                      max_zoom=1.2,
                      keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"],
                        prob=0.5,
                        rotate_range=np.pi / 2,
                        mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image",
                                   prob=0.5,
                                   factors=0.05,
                                   nonzero=True),
        ])

        val_transforms = Compose([
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=(1.0, 1.0, 1.0),
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ])

        max_epochs = 5
        learning_rate = 2e-4
        val_interval = 1  # do validation for every epoch

        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
        train_ds = CacheDataset(data=train_files,
                                transform=train_transforms,
                                cache_rate=1.0,
                                num_workers=8)
        val_ds = CacheDataset(data=val_files,
                              transform=val_transforms,
                              cache_rate=1.0,
                              num_workers=5)
        # disable multi-workers because `ThreadDataLoader` works with multi-threads
        train_loader = ThreadDataLoader(train_ds,
                                        num_workers=0,
                                        batch_size=4,
                                        shuffle=True)
        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

        loss_function = DiceCELoss(to_onehot_y=True,
                                   softmax=True,
                                   squared_pred=True,
                                   batch=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(device)

        # the Novograd paper suggests using a larger LR than Adam,
        # because Adam normalizes by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
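        # GradScaler performs the loss scaling needed for mixed-precision (AMP) training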
        scaler = torch.cuda.amp.GradScaler()

        post_pred = Compose(
            [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
        post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

        dice_metric = DiceMetric(include_background=True,
                                 reduction="mean",
                                 get_not_nans=False)

        best_metric = -1
        total_start = time.time()
        for epoch in range(max_epochs):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step_start = time.time()
                step += 1
                optimizer.zero_grad()
                # set AMP for training
                with torch.cuda.amp.autocast():
                    outputs = model(batch_data["image"])
                    loss = loss_function(outputs, batch_data["label"])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_loss += loss.item()
                epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                      f" step time: {(time.time() - step_start):.4f}")
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_loader:
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        # set AMP for validation
                        with torch.cuda.amp.autocast():
                            val_outputs = sliding_window_inference(
                                val_data["image"], roi_size, sw_batch_size,
                                model)

                        val_outputs = [
                            post_pred(i) for i in decollate_batch(val_outputs)
                        ]
                        val_labels = [
                            post_label(i)
                            for i in decollate_batch(val_data["label"])
                        ]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    if metric > best_metric:
                        best_metric = metric
                    print(
                        f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                    )
            print(
                f"time consumed by epoch {epoch + 1}: {(time.time() - epoch_start):.4f}"
            )

        total_time = time.time() - total_start
        print(
            f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
        )
        # test expected metrics
        self.assertGreater(best_metric, 0.95)
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
#loss_function = DiceLoss(to_onehot_y=True, softmax=True)
#optimizer = torch.optim.Adam(model.parameters(), 1e-4)

loss_function = DiceCELoss(include_background=True,
                           to_onehot_y=True,
                           softmax=True,
                           lambda_dice=0.5,
                           lambda_ce=0.5)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
dice_metric = DiceMetric(include_background=False, reduction="mean")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       'max',
                                                       factor=0.5)
"""## Execute a typical PyTorch training process"""

epoch_num = 300
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_label = AsDiscrete(to_onehot=True, n_classes=2)

for epoch in range(epoch_num):
Ejemplo n.º 28
0
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 40 random image, mask pairs in tempdir if they are not already present
    if not os.path.exists(tempdir):
        print(f"generating synthetic data to {tempdir}")
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=3,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
    else:
        print(f"Found data in {tempdir}")

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[64, 64, 64],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.35, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model_path = "best_metric_model_segmentation3d_dict.pth"
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path))
        print(f"Loaded model from file '{model_path}'")

    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    num_epochs = 10
    for epoch in range(num_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{num_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(
                        device), val_data["seg"].to(device)
                    roi_size = (64, 64, 64)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    value = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_path)
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
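The validation block above averages Dice scores by hand via metric_sum / metric_count. Later examples in this document (see the Horovod evaluation script below) use the newer MONAI pattern in which DiceMetric accumulates per-iteration results and aggregate() returns the reduced value. A minimal sketch of that pattern, assuming a dictionary-style val_loader, a binary one-channel model as above, and a recent MONAI release:

import torch
from monai.data import decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete, Compose, EnsureType


def validate(model, val_loader, device, roi_size=(64, 64, 64), sw_batch_size=4):
    # the metric accumulates per-image scores internally, replacing the manual running sum
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # post-process each item of the decollated batch before feeding the metric
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)
        mean_dice = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice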
Example No. 29
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # generate 40 random image/mask pairs in the temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        # scale the synthetic image from [0, 1] to [0, 255] before saving as an 8-bit PNG
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    val_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        ToTensor()
    ])
    val_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ToTensor()
    ])

    # define array dataset, data loader
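    # ArrayDataset applies the image and segmentation transform chains with the same random state
    # per item, so RandSpatialCrop / RandRotate90 produce spatially aligned crops for im and seg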
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)
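    # with batch_size=10, AddChannel and RandSpatialCrop((96, 96)), the expected printed shapes
    # are (10, 1, 96, 96) for both the image batch and the segmentation batch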

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    num_epochs = 10
    for epoch in range(num_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{num_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
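In this 2D script the raw UNet logits are turned into a binary mask by post_trans before the Dice metric sees them. A tiny, self-contained sketch of that step on made-up logit values (AsDiscrete with threshold_values=True binarises at the default 0.5 after the sigmoid):

import torch
from monai.transforms import Activations, AsDiscrete, Compose

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

# a 1-channel 2x2 logit map: positive logits land above 0.5 after the sigmoid, negative ones below
logits = torch.tensor([[[-2.0, 0.5], [3.0, -0.1]]])
print(post_trans(logits))  # expected: tensor([[[0., 1.], [1., 0.]]])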
Example No. 30
def evaluate(args):
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False, num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(mp, "_supports_context") and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference needs one image per iteration, so use batch_size=1
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
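    # only rank 0 reads the checkpoint from disk; broadcast_parameters then copies those weights to
    # every other Horovod worker so that all ranks evaluate the same model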

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        if hvd.rank() == 0:
            print("evaluation metric:", metric)