def test_shape(self, input_param, input_data, expected_shape):
    result = RandSpatialCropd(**input_param)(input_data)
    self.assertTupleEqual(result["img"].shape, expected_shape)
     mode="bilinear",
     align_corners=False,
 ),
 SpatialPadd(keys=("input", "mask"), spatial_size=(1120, 1120)),
 RandFlipd(keys=("input", "mask"), prob=0.5, spatial_axis=1),
 RandAffined(
     keys=("input", "mask"),
     prob=0.5,
     rotate_range=np.pi / 14.4,
     translate_range=(70, 70),
     scale_range=(0.1, 0.1),
     as_tensor_output=False,
 ),
 RandSpatialCropd(
     keys=("input", "mask"),
     roi_size=(cfg.img_size[0], cfg.img_size[1]),
     random_size=False,
 ),
 RandScaleIntensityd(keys="input", factors=(-0.2, 0.2), prob=0.5),
 RandShiftIntensityd(keys="input", offsets=(-51, 51), prob=0.5),
 RandLambdad(keys="input", func=lambda x: 255 - x, prob=0.5),
 RandCoarseDropoutd(
     keys=("input", "mask"),
     holes=8,
     spatial_size=(1, 1),
     max_spatial_size=(102, 102),
     prob=0.5,
 ),
 CastToTyped(keys="input", dtype=np.float32),
 NormalizeIntensityd(keys="input", nonzero=False),
 Lambdad(keys="input", func=lambda x: x.clip(-20, 20)),
Example #3
from typing import List, Tuple

import numpy as np

from monai.transforms import (
    RandRotate,
    RandRotate90,
    RandRotate90d,
    RandRotated,
    RandSpatialCrop,
    RandSpatialCropd,
    RandZoom,
    RandZoomd,
)
from monai.utils import set_determinism

TESTS: List[Tuple] = []

TESTS.append((dict, RandSpatialCropd("image",
                                     roi_size=[8, 7],
                                     random_size=True)))
TESTS.append((dict, RandRotated("image",
                                prob=1,
                                range_x=np.pi,
                                keep_size=False)))
TESTS.append((dict,
              RandZoomd("image",
                        prob=1,
                        min_zoom=1.1,
                        max_zoom=2.0,
                        keep_size=False)))
TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2)))

TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
Example #4
def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Data loader creation

    # train images
    train_images = sorted(
        glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    train_images_for_dice = sorted(
        glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(
        glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(
        glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(
        glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    # (each pass of this loop doubles the lists, so they grow by a factor of 2**increase_factor_data)
    for i in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader

    train_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{
        'image': image_name,
        'label': label_name
    }
                        for image_name, label_name in zip(
                            train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list

    if opt.resolution is not None:
        train_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Spacingd(keys=['image', 'label'],
                     pixdim=opt.resolution,
                     mode=('bilinear', 'nearest')),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36),
                        padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'],
                           mode=('bilinear', 'nearest'),
                           prob=0.1,
                           sigma_range=(5, 8),
                           magnitude_range=(100, 200),
                           scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            # note: these np.random.uniform calls run once, when the transform list
            # is built, so the same mean/std/offsets are reused for every sample
            RandGaussianNoised(keys=['image'],
                               prob=0.1,
                               mean=np.random.uniform(0, 0.5),
                               std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'],
                                offsets=np.random.uniform(0, 0.3),
                                prob=0.1),
            RandSpatialCropd(keys=['image', 'label'],
                             roi_size=opt.patch_size,
                             random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Spacingd(keys=['image', 'label'],
                     pixdim=opt.resolution,
                     mode=('bilinear', 'nearest')),
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36),
                        padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'],
                           mode=('bilinear', 'nearest'),
                           prob=0.1,
                           sigma_range=(5, 8),
                           magnitude_range=(100, 200),
                           scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'],
                               prob=0.1,
                               mean=np.random.uniform(0, 0.5),
                               std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'],
                                offsets=np.random.uniform(0, 0.3),
                                prob=0.1),
            RandSpatialCropd(keys=['image', 'label'],
                             roi_size=opt.patch_size,
                             random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts,
                                     transform=train_transforms)
    train_loader = DataLoader(check_train,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.workers,
                              pin_memory=torch.cuda.is_available())

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts,
                                   transform=val_transforms)
    train_dice_loader = DataLoader(check_val,
                                   batch_size=1,
                                   num_workers=opt.workers,
                                   pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val,
                            batch_size=1,
                            num_workers=opt.workers,
                            pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val,
                             batch_size=1,
                             num_workers=opt.workers,
                             pin_memory=torch.cuda.is_available())

    # try to use all the available GPUs
    devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(
            ), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):

                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(
                        ), data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, net)
                        value = dice_metric(y_pred=val_outputs, y=val_labels)
                        metric_count += len(value)
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(
                    val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(
                    train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(
                    test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric_train, metric, metric_test,
                            best_metric, best_metric_epoch))

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="validation image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="validation label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="validation inference")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
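
In the example above, training uses random patches (RandSpatialCropd with opt.patch_size) while validation runs sliding_window_inference over whole volumes. A minimal sketch of that call on dummy data (the network, shapes and roi_size here are illustrative, not the ones from the example):

import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet

# small 3D network and a volume larger than the inference window
net = UNet(dimensions=3, in_channels=1, out_channels=1,
           channels=(8, 16, 32), strides=(2, 2)).eval()
volume = torch.rand(1, 1, 96, 96, 96)
with torch.no_grad():
    pred = sliding_window_inference(volume, roi_size=(64, 64, 64),
                                    sw_batch_size=2, predictor=net)
print(pred.shape)  # spatial size matches the full input volume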
Example #5
def main(hparams):
	print('===== INITIAL PARAMETERS =====')
	print('Model name: ', hparams.name)
	print('Batch size: ', hparams.batch_size)
	print('Patch size: ', hparams.patch_size)
	print('Epochs: ', hparams.epochs)
	print('Learning rate: ', hparams.learning_rate)
	print('Loss function: ', hparams.loss)
	print()

	### Data collection
	data_dir = 'data/'
	print('Available directories: ', os.listdir(data_dir))
	# Get paths for images and masks, organize into dictionaries
	images = sorted(glob.glob(data_dir + '**/*CTImg*', recursive=True))
	masks = sorted(glob.glob(data_dir + '**/*Mask*', recursive=True))
	data_dicts = [{'image': image_file, 'mask': mask_file} for image_file, mask_file in zip(images, masks)]
	# Dataset selection
	train_dicts = select_animals(images, masks, [12, 13, 14, 18, 20])
	val_dicts = select_animals(images, masks, [25])
	test_dicts = select_animals(images, masks, [27])
	data_keys = ['image', 'mask']
	# Data transformation
	data_transforms = Compose([
	    LoadNiftid(keys=data_keys),
   	    AddChanneld(keys=data_keys),
	    ScaleIntensityd(keys=data_keys),
	    CropForegroundd(keys=data_keys, source_key='image'),
	    RandSpatialCropd(
	        keys=data_keys,
	        roi_size=(hparams.patch_size, hparams.patch_size, 1),
	        random_size=False
	    ),
	])
	train_transforms = Compose([
	    data_transforms,
  	    ToTensord(keys=data_keys)
	])
	val_transforms = Compose([
	    data_transforms,
   	    ToTensord(keys=data_keys)
	])
	test_transforms = Compose([
	    data_transforms,
   	    ToTensord(keys=data_keys)
	])
	# Data loaders
	data_loaders = {
	    'train': create_loader(train_dicts, batch_size=hparams.batch_size, transforms=train_transforms, shuffle=True),
	    'val': create_loader(val_dicts, transforms=val_transforms),
	    'test': create_loader(test_dicts, transforms=test_transforms)
	}
	for key in data_loaders:
		print(key, len(data_loaders[key]))



	### Model training
	if hparams.loss == 'Dice':
		criterion = monai.losses.DiceLoss(to_onehot_y=True, do_softmax=True)
	elif hparams.loss == 'CrossEntropy':
		criterion = nn.CrossEntropyLoss()
        
	model = UNet(
	    dimensions=2,
	    in_channels=1,
	    out_channels=2,
	    channels=(64, 128, 258, 512, 1024),
	    strides=(2, 2, 2, 2),
	    norm=monai.networks.layers.Norm.BATCH,
	    criterion=criterion,
	    hparams=hparams,
	)

	early_stopping = EarlyStopping('val_loss')
	checkpoint_callback = ModelCheckpoint
	logger = TensorBoardLogger('models/' + hparams.name + '/tb_logs', name=hparams.name)
	
	trainer = Trainer(
	    check_val_every_n_epoch=5,
	    default_save_path='models/' + hparams.name + '/checkpoints',
	#     early_stop_callback=early_stopping,
	    gpus=1,
	    max_epochs=hparams.epochs,
	#     min_epochs=10,
	    logger=logger
	)

	trainer.fit(
	    model,
	    train_dataloader=data_loaders['train'],
	    val_dataloaders=data_loaders['val']
	)
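
A side note on the transform setup above: Compose is itself a MONAI transform, so a shared data_transforms pipeline can be nested inside the split-specific train/val/test pipelines, as done here. A tiny sketch of that nesting with made-up keys:

import numpy as np
from monai.transforms import Compose, ScaleIntensityd, ToTensord

shared = Compose([ScaleIntensityd(keys="image")])      # common preprocessing
train_tf = Compose([shared, ToTensord(keys="image")])  # split-specific additions
out = train_tf({"image": np.random.rand(1, 16, 16).astype(np.float32)})
print(type(out["image"]))  # torch.Tensor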
Example #6
        "2D",
        0,
        SpatialCropd(KEYS, [49, 51], [390, 89]),
    )
)

TESTS.append(
    (
        "SpatialCropd 3d",
        "3D",
        0,
        SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]),
    )
)

TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], True, False)))

TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], False, False)))

TESTS.append(
    (
        "BorderPadd 2d",
        "2D",
        0,
        BorderPadd(KEYS, [3, 7, 2, 5]),
    )
)

TESTS.append(
    (
        "BorderPadd 2d",
Example #7
from monai.utils import set_determinism


@wraps(pad_list_data_collate)
def _testing_collate(x):
    return pad_list_data_collate(batch=x, method="end", mode="constant")


TESTS: List[Tuple] = []

for pad_collate in [
        _testing_collate,
        PadListDataCollate(method="end", mode="constant")
]:
    TESTS.append((dict, pad_collate,
                  RandSpatialCropd("image", roi_size=[8, 7],
                                   random_size=True)))
    TESTS.append((dict, pad_collate,
                  RandRotated("image",
                              prob=1,
                              range_x=np.pi,
                              keep_size=False,
                              dtype=np.float64)))
    TESTS.append((dict, pad_collate,
                  RandZoomd("image",
                            prob=1,
                            min_zoom=1.1,
                            max_zoom=2.0,
                            keep_size=False)))
    TESTS.append((dict, pad_collate,
                  Compose([
                      RandRotate90d("image", prob=1, max_k=3),
Example #8
def test_random_shape(self, input_param, input_data, expected_shape):
    cropper = RandSpatialCropd(**input_param)
    cropper.set_random_state(seed=123)
    result = cropper(input_data)
    self.assertTupleEqual(result["img"].shape, expected_shape)
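
The test above pins the crop location by seeding the transform. A standalone sketch of the same idea on dummy data (shapes and seed chosen arbitrarily):

import numpy as np
from monai.transforms import RandSpatialCropd

cropper = RandSpatialCropd(keys="img", roi_size=[4, 4], random_size=False)
cropper.set_random_state(seed=123)          # makes the random crop reproducible
out = cropper({"img": np.arange(64).reshape(1, 8, 8)})
print(out["img"].shape)  # (1, 4, 4)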
Example #9
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    train_transforms = Compose(
        [
            # load 4 Nifti images and stack them together
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device=device),
            RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ]
    )

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.batch_size, shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=args.batch_size, shuffle=False)

    # create network, loss function and optimizer
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, dice_metric, dice_metric_batch, post_trans
            )

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}"
    )
    dist.destroy_process_group()
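
Example #9 above creates a torch.cuda.amp.GradScaler, but the train helper it passes the scaler to is not shown. A hypothetical sketch of what one mixed-precision step with that scaler typically looks like (not the original helper):

import torch

def amp_step(model, loss_function, optimizer, scaler, inputs, labels):
    # forward and loss under autocast, scaled backward, optimizer step via the scaler
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()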
Example #10
def test_value(self, input_param, input_data):
    cropper = RandSpatialCropd(**input_param)
    result = cropper(input_data)
    roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size]
    np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]])
Example #11
TESTS.append((
    "SpatialCropd 2d",
    "2D",
    0,
    SpatialCropd(KEYS, [49, 51], [390, 89]),
))

TESTS.append((
    "SpatialCropd 3d",
    "3D",
    0,
    SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]),
))

TESTS.append(("RandSpatialCropd 2d", "2D", 0,
              RandSpatialCropd(KEYS, [96, 93], None, True, False)))

TESTS.append(("RandSpatialCropd 3d", "3D", 0,
              RandSpatialCropd(KEYS, [96, 93, 92], None, False, False)))

TESTS.append((
    "BorderPadd 2d",
    "2D",
    0,
    BorderPadd(KEYS, [3, 7, 2, 5]),
))

TESTS.append((
    "BorderPadd 2d",
    "2D",
    0,
Example #12
        LoadNiftid(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
        # RandFlipd(keys=['image', 'label'], prob=1, spatial_axis=2),
        # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1,
        #             rotate_range=(np.pi / 36, np.pi / 4, np.pi / 36)),
        # Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1,
        #                sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.20, 0.20, 0.20)),
        # RandAdjustContrastd(keys=['image'], gamma=(0.5, 3), prob=1),
        # RandGaussianNoised(keys=['image'], prob=1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        # RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=1),
        # BorderPadd(keys=['image', 'label'],spatial_border=(16,16,0)),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        # Orientationd(keys=["image", "label"], axcodes="PLI"),
        ToTensord(keys=['image', 'label'])
    ]

    transform = Compose(monai_transforms)

    check_ds = monai.data.Dataset(data=data_dicts, transform=transform)

    loader = DataLoader(check_ds, batch_size=opt.batch_size, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(loader)
    im, seg = (check_data['image'][0], check_data['label'][0])
    print(im.shape, seg.shape)

    vol = im[0].numpy()
    mask = seg[0].numpy()
Example #13
"""## Setup transforms for training and validation"""

train_transform = Compose([
    # load 4 Nifti images and stack them together
    LoadImaged(keys=["image", "label"]),
    AsChannelFirstd(keys="image"),
    ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
    Spacingd(
        keys=["image", "label"],
        pixdim=(1.5, 1.5, 2.0),
        mode=("bilinear", "nearest"),
    ),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    RandSpatialCropd(keys=["image", "label"],
                     roi_size=[128, 128, 64],
                     random_size=False),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
    RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ToTensord(keys=["image", "label"]),
])
val_transform = Compose([
    LoadImaged(keys=["image", "label"]),
    AsChannelFirstd(keys="image"),
    ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
    Spacingd(
        keys=["image", "label"],
        pixdim=(1.5, 1.5, 2.0),
        mode=("bilinear", "nearest"),
Example #14
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[128, 128, 64],
                         random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        task="Task01_BrainTumour",
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        task="Task01_BrainTumour",
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
    )
    val_loader = DataLoader(val_ds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if dist.get_rank() == 0:
        # Logging for TensorBoard
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4,
                          out_channels=3,
                          init_filters=16,
                          dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False,
                             sigmoid=True,
                             squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5,
                                 amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer,
                           epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)
            metrics = torch.tensor([metric, metric_tc, metric_wt,
                                    metric_et]).to(device)
            dist.all_reduce(metrics, op=torch.distributed.ReduceOp.SUM)

            if dist.get_rank() == 0:
                metrics = metrics / dist.get_world_size()

                writer.add_scalar("Mean Dice/val", metrics[0], epoch)
                writer.add_scalar("Mean Dice TC/val", metrics[1], epoch)
                writer.add_scalar("Mean Dice WT/val", metrics[2], epoch)
                writer.add_scalar("Mean Dice ET/val", metrics[3], epoch)
                if metrics[0] > best_metric:
                    best_metric = metrics[0]
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metrics[0]:.4f}"
                    f" tc: {metrics[1]:.4f} wt: {metrics[2]:.4f} et: {metrics[3]:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
        writer.flush()
    dist.destroy_process_group()