def test_shape(self, input_param, input_data, expected_shape):
    # Build the dictionary-based cropper from the parameterized kwargs,
    # apply it once, and compare only the "img" entry's output shape.
    cropper = RandSpatialCropd(**input_param)
    output = cropper(input_data)
    self.assertTupleEqual(output["img"].shape, expected_shape)
mode="bilinear", align_corners=False, ), SpatialPadd(keys=("input", "mask"), spatial_size=(1120, 1120)), RandFlipd(keys=("input", "mask"), prob=0.5, spatial_axis=1), RandAffined( keys=("input", "mask"), prob=0.5, rotate_range=np.pi / 14.4, translate_range=(70, 70), scale_range=(0.1, 0.1), as_tensor_output=False, ), RandSpatialCropd( keys=("input", "mask"), roi_size=(cfg.img_size[0], cfg.img_size[1]), random_size=False, ), RandScaleIntensityd(keys="input", factors=(-0.2, 0.2), prob=0.5), RandShiftIntensityd(keys="input", offsets=(-51, 51), prob=0.5), RandLambdad(keys="input", func=lambda x: 255 - x, prob=0.5), RandCoarseDropoutd( keys=("input", "mask"), holes=8, spatial_size=(1, 1), max_spatial_size=(102, 102), prob=0.5, ), CastToTyped(keys="input", dtype=np.float32), NormalizeIntensityd(keys="input", nonzero=False), Lambdad(keys="input", func=lambda x: x.clip(-20, 20)),
from monai.transforms import (
    RandRotate,
    RandRotate90,
    RandRotate90d,
    RandRotated,
    RandSpatialCrop,
    RandSpatialCropd,
    RandZoom,
    RandZoomd,
)
from monai.utils import set_determinism

# Test table: each entry pairs a container type with a randomised transform
# configured so its output spatial size may vary between calls
# (random_size=True, keep_size=False, or 90-degree rotations).
# Dictionary-based transforms (``*d``) are paired with ``dict`` and
# array-based transforms with ``list``.
TESTS: List[Tuple] = [
    (dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)),
    (dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)),
    (dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)),
    (dict, RandRotate90d("image", prob=1, max_k=2)),
    (list, RandSpatialCrop(roi_size=[8, 7], random_size=True)),
    (list, RandRotate(prob=1, range_x=np.pi, keep_size=False)),
]
def main():
    """Train a segmentation network on NIfTI image/label pairs.

    Reads the run configuration from ``Options().parse()``, builds
    train / train-for-Dice / validation / test loaders from
    ``<images_folder>/<split>/image*.nii`` and
    ``<labels_folder>/<split>/label*.nii``, trains with a Tversky loss and,
    after every epoch (``val_interval = 1``), computes sliding-window Dice on
    the validation, training and test sets. The best validation model is
    saved to ``best_metric_model.pth``; losses, Dice scores and sample
    images are logged to TensorBoard.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # '-1' means "no GPU ids given"; otherwise count the comma-separated ids.
    # NOTE(review): the network is moved to CUDA below regardless of num_gpus.
    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(
        glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # Separate, non-replicated copy of the training list, used only to
    # measure Dice on the training set.
    train_images_for_dice = sorted(
        glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(
        glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(
        glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(
        glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    # NOTE(review): each iteration extends the list with *itself*, so its
    # length is multiplied by 2 ** increase_factor_data, not by
    # increase_factor_data. Confirm this exponential growth is intended.
    for i in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader
    # (images and labels are paired positionally after sorting)
    train_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(
        train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list
    # The two branches are identical except that resampling (Spacingd) to
    # opt.resolution is only applied when a resolution is given.
    # NOTE(review): np.random.uniform(...) in RandGaussianNoised /
    # RandShiftIntensityd is evaluated once, when this list is built, so
    # mean/std/offsets are fixed for the whole run rather than resampled per
    # sample. Confirm intent.
    if opt.resolution is not None:
        train_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"], a_min=-120, a_max=170,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        # Evaluation pipeline: same preprocessing, no augmentation and no
        # cropping (whole volumes go through sliding-window inference).
        val_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"], a_min=-120, a_max=170,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"], a_min=-120, a_max=170,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"], a_min=-120, a_max=170,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a test data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # try to use all the available GPUs
    # NOTE(review): `devices` is never used below.
    devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    # The checkpoint is loaded after the optional DataParallel wrap, so its
    # parameter names must match the (wrapped or unwrapped) model. Presumably
    # it comes from this script's own torch.save(net.state_dict()) with the
    # same GPU count. TODO confirm.
    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            # Floor of samples / batch size; used for the progress print and
            # the global step of the per-batch TensorBoard scalar.
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):
                    # Despite its name, this does not plot. It returns the mean
                    # Dice over `images_loader` plus the last batch's
                    # (image, label, raw network output) for visualisation.
                    # NOTE(review): raises ZeroDivisionError if the loader
                    # yields no batches (metric_count stays 0). It also appends
                    # to metric_values on every call, and it is called for val,
                    # train and test, so that list interleaves all three.
                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(), data["label"].cuda()
                        # Inference window equals the training patch size.
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        # NOTE(review): value.item() assumes dice_metric returns a
                        # one-element tensor (these loaders use batch_size=1).
                        # Confirm for the installed MONAI version.
                        value = dice_metric(y_pred=val_outputs, y=val_labels)
                        metric_count += len(value)
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save best model (selection is on validation Dice only)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch))

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)

                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # (sigmoid + 0.5 threshold turns raw outputs into a binary mask)
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
def main(hparams):
    """Train a 2D UNet on CT image / mask NIfTI files found under ``data/``.

    Prints the run configuration, collects image and mask paths, splits them
    with ``select_animals`` into train (ids 12, 13, 14, 18, 20), validation
    (25) and test (27) sets, builds the data loaders and fits the model with
    a ``Trainer`` that logs to ``models/<name>/tb_logs``.

    Args:
        hparams: parsed hyper-parameters. This function reads ``name``,
            ``batch_size``, ``patch_size``, ``epochs``, ``learning_rate`` and
            ``loss`` (``'Dice'`` or ``'CrossEntropy'``). The object is also
            forwarded to the model.

    Raises:
        ValueError: if ``hparams.loss`` is not a supported loss name.
    """
    print('===== INITIAL PARAMETERS =====')
    print('Model name: ', hparams.name)
    print('Batch size: ', hparams.batch_size)
    print('Patch size: ', hparams.patch_size)
    print('Epochs: ', hparams.epochs)
    print('Learning rate: ', hparams.learning_rate)
    print('Loss function: ', hparams.loss)
    print()

    ### Data collection
    data_dir = 'data/'
    print('Available directories: ', os.listdir(data_dir))

    # Get paths for images and masks (recursive search; sorted so that
    # images and masks line up positionally).
    images = sorted(glob.glob(data_dir + '**/*CTImg*', recursive=True))
    masks = sorted(glob.glob(data_dir + '**/*Mask*', recursive=True))

    # Dataset selection
    train_dicts = select_animals(images, masks, [12, 13, 14, 18, 20])
    val_dicts = select_animals(images, masks, [25])
    test_dicts = select_animals(images, masks, [27])

    data_keys = ['image', 'mask']

    # Data transformation shared by every split. It ends in a random crop,
    # so validation/test samples are random patches too. roi_size has a
    # depth of 1, presumably to feed single slices to the 2D UNet (TODO
    # confirm how the singleton axis is handled).
    data_transforms = Compose([
        LoadNiftid(keys=data_keys),
        AddChanneld(keys=data_keys),
        ScaleIntensityd(keys=data_keys),
        CropForegroundd(keys=data_keys, source_key='image'),
        RandSpatialCropd(
            keys=data_keys,
            roi_size=(hparams.patch_size, hparams.patch_size, 1),
            random_size=False
        ),
    ])
    train_transforms = Compose([
        data_transforms,
        ToTensord(keys=data_keys)
    ])
    val_transforms = Compose([
        data_transforms,
        ToTensord(keys=data_keys)
    ])
    test_transforms = Compose([
        data_transforms,
        ToTensord(keys=data_keys)
    ])

    # Data loaders
    data_loaders = {
        'train': create_loader(train_dicts, batch_size=hparams.batch_size, transforms=train_transforms, shuffle=True),
        'val': create_loader(val_dicts, transforms=val_transforms),
        'test': create_loader(test_dicts, transforms=test_transforms)
    }

    for key in data_loaders:
        print(key, len(data_loaders[key]))

    ### Model training
    if hparams.loss == 'Dice':
        criterion = monai.losses.DiceLoss(to_onehot_y=True, do_softmax=True)
    elif hparams.loss == 'CrossEntropy':
        criterion = nn.CrossEntropyLoss()
    else:
        # Fail fast with a clear message; otherwise `criterion` would be
        # unbound below and surface as an UnboundLocalError.
        raise ValueError(
            f"Unsupported loss {hparams.loss!r}; expected 'Dice' or 'CrossEntropy'"
        )

    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=2,
        # Channel widths double at every level (256, not 258).
        channels=(64, 128, 256, 512, 1024),
        strides=(2, 2, 2, 2),
        norm=monai.networks.layers.Norm.BATCH,
        criterion=criterion,
        hparams=hparams,
    )

    # Currently unused: the early_stop_callback argument below is commented out.
    early_stopping = EarlyStopping('val_loss')

    logger = TensorBoardLogger('models/' + hparams.name + '/tb_logs', name=hparams.name)

    trainer = Trainer(
        check_val_every_n_epoch=5,
        default_save_path='models/' + hparams.name + '/checkpoints',
        # early_stop_callback=early_stopping,
        gpus=1,
        max_epochs=hparams.epochs,
        # min_epochs=10,
        logger=logger
    )

    trainer.fit(
        model,
        train_dataloader=data_loaders['train'],
        val_dataloaders=data_loaders['val']
    )
"2D", 0, SpatialCropd(KEYS, [49, 51], [390, 89]), ) ) TESTS.append( ( "SpatialCropd 3d", "3D", 0, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), ) ) TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], True, False))) TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) TESTS.append( ( "BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7, 2, 5]), ) ) TESTS.append( ( "BorderPadd 2d",
from monai.utils import set_determinism @wraps(pad_list_data_collate) def _testing_collate(x): return pad_list_data_collate(batch=x, method="end", mode="constant") TESTS: List[Tuple] = [] for pad_collate in [ _testing_collate, PadListDataCollate(method="end", mode="constant") ]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False, dtype=np.float64))) TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) TESTS.append((dict, pad_collate, Compose([ RandRotate90d("image", prob=1, max_k=3),
def test_random_shape(self, input_param, input_data, expected_shape):
    # Fix the transform's random state (seed=123) so the cropped shape is
    # reproducible, then check only the "img" entry.
    transform = RandSpatialCropd(**input_param)
    transform.set_random_state(seed=123)
    output = transform(input_data)
    self.assertTupleEqual(output["img"].shape, expected_shape)
def main_worker(args):
    """Per-process entry point for distributed (NCCL) training with AMP.

    Each process binds to ``cuda:<args.local_rank>``. Processes with a
    non-zero local rank have stdout/stderr redirected to ``os.devnull``.
    Data are cached in ``BratsCacheDataset`` and moved to the GPU inside the
    transform chain. Training uses DiceFocal loss, Novograd and a cosine LR
    schedule. Rank 0 saves the best-validation checkpoint to
    ``best_metric_model.pth``.

    Args:
        args: parsed CLI namespace. Reads ``local_rank``, ``dir``,
            ``cache_rate``, ``batch_size``, ``network``, ``lr``, ``epochs``
            and ``val_interval``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node
    # (the devnull handle is intentionally left open: it replaces
    # stdout/stderr for the rest of the process)
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    # Make this rank's GPU the current device so NCCL collectives and DDP
    # use it rather than cuda:0.
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    # ToDeviced sits before the random transforms, so (presumably) cropping,
    # flipping and intensity augmentation operate on GPU tensors. Confirm
    # against the installed MONAI version.
    train_transforms = Compose(
        [
            # load 4 Nifti images and stack them together
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device=device),
            RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ]
    )

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.batch_size, shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=args.batch_size, shuffle=False)

    # create network, loss function and optimizer
    # (4 input channels, 3 output channels in both architectures)
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # "mean" gives the overall Dice; "mean_batch" keeps per-channel values
    # (unpacked by evaluate() as tc / wt / et).
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

    # Sigmoid then 0.5 threshold turns logits into binary masks for scoring.
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, dice_metric, dice_metric_batch, post_trans
            )

            # Every rank tracks the best metric; only rank 0 writes the file.
            # The saved state_dict is the DDP wrapper's, so parameter names
            # carry the "module." prefix.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}"
    )
    dist.destroy_process_group()
def test_value(self, input_param, input_data):
    transform = RandSpatialCropd(**input_param)
    output = transform(input_data)
    # Expected window for each sampled ROI size: it starts at 2 - size // 2
    # and ends at 2 + size - size // 2 (index 2 is presumably the centre of
    # the test image; TODO confirm). Only the first two spatial axes are
    # sliced.
    windows = [(2 - size // 2, 2 + size - size // 2) for size in transform._size]
    (first_lo, first_hi), (second_lo, second_hi) = windows[0], windows[1]
    expected = input_data["img"][:, first_lo:first_hi, second_lo:second_hi]
    np.testing.assert_allclose(output["img"], expected)
TESTS.append(( "SpatialCropd 2d", "2D", 0, SpatialCropd(KEYS, [49, 51], [390, 89]), )) TESTS.append(( "SpatialCropd 3d", "3D", 0, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), )) TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], None, True, False))) TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False))) TESTS.append(( "BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7, 2, 5]), )) TESTS.append(( "BorderPadd 2d", "2D", 0,
LoadNiftid(keys=['image', 'label']), AddChanneld(keys=['image', 'label']), NormalizeIntensityd(keys=['image']), ScaleIntensityd(keys=['image']), # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')), # RandFlipd(keys=['image', 'label'], prob=1, spatial_axis=2), # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1, # rotate_range=(np.pi / 36, np.pi / 4, np.pi / 36)), # Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1, # sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.20, 0.20, 0.20)), # RandAdjustContrastd(keys=['image'], gamma=(0.5, 3), prob=1), # RandGaussianNoised(keys=['image'], prob=1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)), # RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=1), # BorderPadd(keys=['image', 'label'],spatial_border=(16,16,0)), RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False), # Orientationd(keys=["image", "label"], axcodes="PLI"), ToTensord(keys=['image', 'label']) ] transform = Compose(monai_transforms) check_ds = monai.data.Dataset(data=data_dicts, transform=transform) loader = DataLoader(check_ds, batch_size=opt.batch_size, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(loader) im, seg = (check_data['image'][0], check_data['label'][0]) print(im.shape, seg.shape) vol = im[0].numpy() mask = seg[0].numpy()
"""## Setup transforms for training and validation""" train_transform = Compose([ # load 4 Nifti images and stack them together LoadImaged(keys=["image", "label"]), AsChannelFirstd(keys="image"), ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), Spacingd( keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest"), ), Orientationd(keys=["image", "label"], axcodes="RAS"), RandSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64], random_size=False), RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0), NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), RandScaleIntensityd(keys="image", factors=0.1, prob=0.5), RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5), ToTensord(keys=["image", "label"]), ]) val_transform = Compose([ LoadImaged(keys=["image", "label"]), AsChannelFirstd(keys="image"), ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), Spacingd( keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest"),
def main_worker(args):
    """Per-process entry point for distributed (NCCL) training.

    One process runs per GPU, and ``args.local_rank`` selects its CUDA
    device. Processes with a non-zero local rank have stdout/stderr
    redirected to ``os.devnull``. Validation Dice scores are summed across
    ranks with ``all_reduce`` and averaged on rank 0. Rank 0 writes
    TensorBoard logs and saves ``final_model.pth`` at the end.

    Args:
        args: parsed CLI namespace. Reads ``local_rank``, ``dir``,
            ``cache_rate``, ``batch_size``, ``workers``, ``log_dir``,
            ``network``, ``lr``, ``epochs``, ``print_freq`` and
            ``val_interval``; the whole object is also forwarded to ``train``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node
    # (the devnull handle is intentionally left open: it replaces
    # stdout/stderr for the rest of the process)
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    # NCCL collectives (the all_reduce below) and DDP operate on the
    # *current* CUDA device. Without this call every rank would default to
    # cuda:0.
    torch.cuda.set_device(device)

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64], random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        task="Task01_BrainTumour",
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
    )
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

    # validation transforms and dataset (deterministic centre crop)
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        task="Task01_BrainTumour",
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
    )
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    # `writer` only exists on rank 0; every use below is guarded by a rank check.
    if dist.get_rank() == 0:
        # Logging for TensorBoard
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4, out_channels=3, init_filters=16, dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5, amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer, epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(model, val_loader, device)
            # Sum each rank's scores, then average on rank 0.
            metrics = torch.tensor([metric, metric_tc, metric_wt, metric_et]).to(device)
            dist.all_reduce(metrics, op=torch.distributed.ReduceOp.SUM)

            if dist.get_rank() == 0:
                metrics = metrics / dist.get_world_size()
                writer.add_scalar("Mean Dice/val", metrics[0], epoch)
                writer.add_scalar("Mean Dice TC/val", metrics[1], epoch)
                writer.add_scalar("Mean Dice WT/val", metrics[2], epoch)
                writer.add_scalar("Mean Dice ET/val", metrics[3], epoch)
                if metrics[0] > best_metric:
                    best_metric = metrics[0]
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metrics[0]:.4f}"
                    f" tc: {metrics[1]:.4f} wt: {metrics[2]:.4f} et: {metrics[3]:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
        # close() flushes pending events and releases the event file;
        # flush() alone left the writer open.
        writer.close()
    dist.destroy_process_group()