def test_rand_3d_elasticd(self, input_param, input_data, expected_val):
    """Compare seeded Rand3DElasticd output with the expected value of every key.

    The transform's random state is fixed to 123 so the deformation is
    reproducible. ``expected_val`` is either a dict keyed like the transform
    output, or one value that every output entry is compared against.
    """
    transform = Rand3DElasticd(**input_param)
    transform.set_random_state(123)
    outputs = transform(input_data)
    per_key = isinstance(expected_val, dict)
    for name in outputs:
        target = expected_val[name] if per_key else expected_val
        assert_allclose(outputs[name], target, rtol=1e-4, atol=1e-4)
def test_rand_3d_elasticd(self, input_param, input_data, expected_val):
    """Compare seeded Rand3DElasticd output with the expected value of every key.

    The random state is fixed to 123 for reproducibility. Each output must have
    the same container kind (torch tensor or not) as its expectation; tensors
    are moved to the CPU and converted to numpy before the numeric comparison.
    ``expected_val`` may be a dict keyed like the output or a single value.
    """
    elastic = Rand3DElasticd(**input_param)
    elastic.set_random_state(123)
    outputs = elastic(input_data)
    for name in outputs:
        actual = outputs[name]
        if isinstance(expected_val, dict):
            target = expected_val[name]
        else:
            target = expected_val
        actual_is_tensor = torch.is_tensor(actual)
        self.assertEqual(actual_is_tensor, torch.is_tensor(target))
        if actual_is_tensor:
            actual = actual.cpu().numpy()
            target = target.cpu().numpy()
        np.testing.assert_allclose(actual, target, rtol=1e-4, atol=1e-4)
keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4, image_key="image", image_threshold=0, ), Rand3DElasticd( keys=["image", "label"], sigma_range=(0, 1), magnitude_range=(0, 1), spatial_size=None, prob=0.5, rotate_range=(0, -math.pi / 36, math.pi / 36, 0), # -15, 15 / -5, 5 shear_range=None, translate_range=None, scale_range=None, mode=("bilinear", "nearest"), padding_mode="zeros", # as_tensor_output=False ), RandGaussianNoised(keys=["image"], prob=0.5, mean=0.0, std=0.1 # allow_missing_keys=False ), #RandScaleIntensityd( # keys=["image"], # factors=0.05, # this is 10%, try 5%
def main():
    """Train a segmentation network and report Dice on train/val/test each epoch.

    Image/label pairs are collected from the ``train``, ``val`` and ``test``
    sub-folders of ``opt.images_folder`` / ``opt.labels_folder`` (files
    matching ``image*.nii`` / ``label*.nii``, paired by sorted order). The
    network is trained with ``DiceCELoss``. Every ``val_interval`` epochs the
    sliding-window Dice is computed on all three splits, the best-validation
    weights are saved to ``best_metric_model.pth``, and scalars and images are
    written to TensorBoard.
    """
    opt = Options().parse()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    num_gpus = len(opt.gpu_ids.split(',')) if opt.gpu_ids != '-1' else 0
    print('number of GPU:', num_gpus)

    # ---------------- Data loader creation ----------------
    def _split_files(split):
        """Return the sorted image paths and label paths of one data split."""
        images = sorted(glob(os.path.join(opt.images_folder, split, 'image*.nii')))
        labels = sorted(glob(os.path.join(opt.labels_folder, split, 'label*.nii')))
        return images, labels

    def _to_dicts(images, labels):
        """Pair image/label paths into the dict format used by the *d transforms."""
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(images, labels)]

    train_images, train_segs = _split_files('train')
    # Un-repeated copy of the training set, used only for the training Dice.
    train_images_for_dice, train_segs_for_dice = list(train_images), list(train_segs)
    val_images, val_segs = _split_files('val')
    test_images, test_segs = _split_files('test')

    # Augment the data list for training.
    # NOTE: every unit of increase_factor_data DOUBLES the list, so it is
    # repeated 2 ** increase_factor_data times (kept for compatibility with
    # existing option values). Negative factors leave the list unchanged.
    repeat = 2 ** max(int(opt.increase_factor_data), 0)
    train_images = train_images * repeat
    train_segs = train_segs * repeat

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    train_dicts = _to_dicts(train_images, train_segs)
    train_dice_dicts = _to_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = _to_dicts(val_images, val_segs)
    test_dicts = _to_dicts(test_images, test_segs)

    # ---------------- Transforms ----------------
    # Need to concatenate multiple channels here if you want multichannel
    # segmentation. Check other examples on the MONAI webpage.
    def _spacing():
        """Resampling to ``opt.resolution``; empty when no resolution is set."""
        if opt.resolution is None:
            return []
        return [Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest'))]

    train_transforms = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        *_spacing(),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200),
                       scale_range=(0.15, 0.15, 0.15), padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # NOTE(review): np.random.uniform runs once, when this list is built,
        # so the noise mean/std and the intensity offset are fixed for the
        # whole run rather than re-drawn per sample.
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label']),
    ])
    val_transforms = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        *_spacing(),
        ToTensord(keys=['image', 'label']),
    ])

    pin_memory = torch.cuda.is_available()
    # Training loader: random augmented patches.
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True,
                              num_workers=opt.workers, pin_memory=pin_memory)

    def _eval_loader(dicts):
        """Full-image loader (batch size 1) for sliding-window Dice evaluation."""
        dataset = monai.data.Dataset(data=dicts, transform=val_transforms)
        return DataLoader(dataset, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    train_dice_loader = _eval_loader(train_dice_dicts)
    val_loader = _eval_loader(val_dicts)
    test_loader = _eval_loader(test_dicts)

    # ---------------- Network / loss / optimiser ----------------
    net = build_net()
    net.cuda()
    if num_gpus > 1:
        net = torch.nn.DataParallel(net)
    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    # Sigmoid followed by thresholding: turns network logits into a binary mask.
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # ---------------- Training ----------------
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter()

    def _evaluate_dice(images_loader):
        """Mean sliding-window Dice over ``images_loader``.

        Returns ``(metric, images, labels, outputs)``, where the last three are
        the final batch (``outputs`` already post-processed by ``post_trans``),
        kept for TensorBoard plotting. The training loop calls this under
        ``torch.no_grad()``.
        """
        metric_sum = 0.0
        metric_count = 0
        images = labels = outputs = None
        for data in images_loader:
            images, labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            outputs = sliding_window_inference(images, roi_size, sw_batch_size, net)
            outputs = post_trans(outputs)
            value, _ = dice_metric(y_pred=outputs, y=labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
        metric = metric_sum / metric_count
        metric_values.append(metric)
        return metric, images, labels, outputs

    epoch_len = len(check_train) // train_loader.batch_size
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_batch_images, val_batch_labels, val_outputs = _evaluate_dice(val_loader)
                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")
                metric_train, _, _, _ = _evaluate_dice(train_dice_loader)
                metric_test, _, _, _ = _evaluate_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # Plot the last validation batch (image, label, prediction) in
                # TensorBoard. val_outputs is already binarised by post_trans;
                # applying sigmoid + ">= 0.5" again would turn every voxel into
                # 1 (sigmoid(0) = 0.5, sigmoid(1) ~ 0.73), so plot it directly.
                plot_2d_or_3d_image(val_batch_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_batch_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
pos=1, neg=0, num_samples=1, image_key=DataType.Image, image_threshold=0, ), # This function handles both affine and elastic deformation together # To minimise padding issues, we will elastic and affine transform first on original before any cropping Rand3DElasticd( keys=[DataType.Image, DataType.Label], sigma_range=(5, 8), # Sigma range for elastic deformation magnitude_range=(100, 200), # maginitude range for elastic deformation mode=('bilinear', 'nearest'), # Output to desired crop size spatial_size=(96, 96, 16), # Probability of augmentation - we will kepp most unchanged and augment 30% prob=0.4, # Only rotate depthwise as patient generally only varies in the depthwise direction rotate_range=(0, 0, np.pi / 15), # Set translation to 0.1 translate_range=(0.1, 0.1, 0.1), # scale in all direction by 0.1 scale_range=(0.1, 0.1, 0.1)), ManualWindowIntensity(keys=DataType.Image), # RandomWindowIntensity(keys=DataType.Image, thresholds=[1024, 512, 256, 128], prob=0.8), # Commented out, manual windowing used instead ConvertToMultiChannelBasedOnLabelsClassesd(keys=DataType.Label), ToTensord(keys=[DataType.Image, DataType.Label]), PermutateTransform(keys=[DataType.Image, DataType.Label]), ]) val_transform = Compose([
def main():
    """Train a k-fold ensemble of DynUNet models and evaluate it on a test split.

    All ``image*.nii`` / ``label*.nii`` files in ``opt.images_folder`` /
    ``opt.labels_folder`` are paired by sorted order. The last
    ``opt.split_test`` pairs are held out for testing. Ensemble member ``i`` is
    validated on the consecutive fold ``[split_val * i, split_val * (i + 1))``
    and trained on the remaining non-test pairs. The trained models are then
    evaluated on the test split with both ``MeanEnsembled`` and
    ``VoteEnsembled`` post-processing.
    """
    opt = Options().parse()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    set_determinism(seed=0)
    device = torch.device(opt.gpu_id)

    # ------- Data loader creation ----------
    images = sorted(glob(os.path.join(opt.images_folder, 'image*.nii')))
    segs = sorted(glob(os.path.join(opt.labels_folder, 'label*.nii')))

    def _to_dicts(img_list, seg_list):
        """Pair image/label paths into the dict format used by the *d transforms."""
        return [{"image": img, "label": seg} for img, seg in zip(img_list, seg_list)]

    # Everything before n_trainval may be used for training/validation; the
    # tail is the test set. The training range must stop at n_trainval
    # (derived from split_test, not split_val), otherwise test cases leak into
    # training whenever split_test > split_val.
    n_trainval = len(images) - opt.split_test
    train_files = []
    val_files = []
    for i in range(opt.models_ensemble):
        start, stop = opt.split_val * i, opt.split_val * (i + 1)
        train_files.append(_to_dicts(images[:start] + images[stop:n_trainval],
                                     segs[:start] + segs[stop:n_trainval]))
        # NOTE(review): if models_ensemble * split_val > n_trainval, the last
        # validation folds overlap the test set — confirm option values.
        val_files.append(_to_dicts(images[start:stop], segs[start:stop]))
    test_files = _to_dicts(images[n_trainval:len(images)], segs[n_trainval:len(images)])

    # ----------- Transforms list --------------
    def _spacing():
        """Resampling to ``opt.resolution``; empty when no resolution is set."""
        if opt.resolution is None:
            return []
        return [Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest'))]

    train_transforms = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        *_spacing(),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200),
                       scale_range=(0.15, 0.15, 0.15), padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # NOTE(review): np.random.uniform runs once, when this list is built,
        # so these noise/shift parameters are fixed for the whole run.
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label']),
    ])
    val_transforms = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        *_spacing(),
        ToTensord(keys=['image', 'label']),
    ])

    # ---------- Creation of DataLoaders -------------
    pin_memory = torch.cuda.is_available()
    train_dss = [CacheDataset(data=files, transform=train_transforms) for files in train_files]
    train_loaders = [
        DataLoader(ds, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, pin_memory=pin_memory)
        for ds in train_dss
    ]
    val_dss = [CacheDataset(data=files, transform=val_transforms) for files in val_files]
    val_loaders = [
        DataLoader(ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)
        for ds in val_dss
    ]
    test_ds = CacheDataset(data=test_files, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    def _train_model(index):
        """Build, train and return the DynUNet of ensemble member ``index``."""
        # ---------- Build the nnU-Net style DynUNet ------------
        if opt.resolution is None:
            sizes, spacings = opt.patch_size, opt.spacing
        else:
            sizes, spacings = opt.patch_size, opt.resolution
        strides, kernels = [], []
        # Per level: an axis is down-sampled (stride 2) while its spacing is
        # within 2x of the finest spacing and its size is still >= 8; such axes
        # get kernel size 3, the others kernel size 1.
        while True:
            spacing_ratio = [sp / min(spacings) for sp in spacings]
            stride = [2 if ratio <= 2 and size >= 8 else 1 for ratio, size in zip(spacing_ratio, sizes)]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            sizes = [sz / st for sz, st in zip(sizes, stride)]
            spacings = [sp * st for sp, st in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])

        net = monai.networks.nets.DynUNet(
            spatial_dims=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            res_block=True,
        ).to(device)

        # Print a layer summary for one random batch of the training patch size.
        # The tensor is placed on `device` (not a hard-coded .cuda()) so it
        # matches the network's device.
        from torchsummaryX import summary
        dummy = torch.randn(int(opt.batch_size), int(opt.in_channels), int(opt.patch_size[0]),
                            int(opt.patch_size[1]), int(opt.patch_size[2])).to(device)
        with torch.no_grad():
            out = net(dummy)
        summary(net, dummy)
        print("out size: {}".format(out.size()))

        optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
        # Polynomial ("poly") learning-rate decay over opt.epochs.
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)
        loss_function = monai.losses.DiceCELoss(sigmoid=True)

        # One checkpoint directory per ensemble member, so the members do not
        # overwrite each other's files (e.g. the final-iteration checkpoint).
        save_dir = os.path.join("./runs", f"model_{index}")

        val_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
        ])
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(save_dir=save_dir, save_dict={"net": net}, save_key_metric=True),
        ]
        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loaders[index],
            network=net,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size, sw_batch_size=opt.batch_size, overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice": MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
            val_handlers=val_handlers)

        train_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
        ])
        train_handlers = [
            ValidationHandler(validator=evaluator, interval=5, epoch_level=True),
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=save_dir, save_dict={"net": net, "opt": optim},
                            save_final=True, epoch_level=True),
        ]
        trainer = SupervisedTrainer(
            device=device,
            max_epochs=opt.epochs,
            train_data_loader=train_loaders[index],
            network=net,
            optimizer=optim,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=train_post_transforms,
            amp=False,
            train_handlers=train_handlers,
        )
        trainer.run()
        return net

    models = [_train_model(i) for i in range(opt.models_ensemble)]

    # -------- Test the models ---------
    def _ensemble_evaluate(post_transforms, networks):
        """Run the test-set ensemble evaluation with the given post-processing."""
        evaluator = EnsembleEvaluator(
            device=device,
            val_data_loader=test_loader,
            pred_keys=opt.pred_keys,
            networks=networks,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size, sw_batch_size=opt.batch_size, overlap=0.5),
            post_transform=post_transforms,
            key_val_metric={
                "test_mean_dice": MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
        )
        evaluator.run()

    mean_post_transforms = Compose([
        # Weighted average of the member predictions; weights come from
        # opt.weights_models.
        MeanEnsembled(keys=opt.pred_keys, output_key="pred", weights=opt.weights_models),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    print('Results from MeanEnsembled:')
    _ensemble_evaluate(mean_post_transforms, models)

    vote_post_transforms = Compose([
        Activationsd(keys=opt.pred_keys, sigmoid=True),
        # transform data into discrete before voting
        AsDiscreted(keys=opt.pred_keys, threshold_values=True),
        VoteEnsembled(keys=opt.pred_keys, output_key="pred"),
    ])
    print('Results from VoteEnsembled:')
    _ensemble_evaluate(vote_post_transforms, models)
def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None):
    """Train a pipeline-parallel (GPipe) UNet with a weighted Dice + focal loss.

    Args:
        n_feat: base number of feature channels passed to ``UNetPipe``.
        crop_size: comma-separated crop size string, e.g. ``"96,96,96"``. If
            any entry is ``-1`` the full image is used; otherwise balanced
            foreground/background windows are sampled.
        bs: batch size (number of windows sampled per image in crop mode).
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"`` (case-insensitive).
        lr: learning rate.
        pretrain: optional checkpoint path. Entries of its ``"weight"`` dict
            whose keys exist in the model are loaded; the rest of the model
            keeps its current weights.

    Raises:
        ValueError: if ``optimizer`` is not one of the supported names.

    Every 10 epochs the per-foreground-class validation Dice is printed. For
    each class ``c`` whose score improves, a checkpoint is saved to
    ``{model_name}{c}``.
    """
    import math

    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    def _elastic():
        """Random elastic + affine augmentation shared by both sampling modes."""
        return Rand3DElasticd(
            keys=("image", "label"),
            spatial_size=crop_size,
            sigma_range=(10, 50),  # 30
            magnitude_range=(600, 1200),  # 1000
            prob=0.8,
            rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
            shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
            translate_range=tuple(sz * 0.05 for sz in crop_size),
            scale_range=(0.2, 0.2, 0.2),
            mode=("bilinear", "nearest"),
            padding_mode=("border", "zeros"),
        )

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if any(cz == -1 for cz in crop_size):
        # using full image
        train_transform = Compose([AddChannelDict(keys="image"), _elastic()])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(keys=("image", "label"), spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"), label_key="label",
                                   spatial_size=crop_size, num_samples=bs),
            _elastic(),
        ])
        # each dataset item is a list of windows
        train_dataset = Dataset(train_images, transform=train_transform)
        # list_data_collate stacks each dataset item (a list of windows) into a single tensor
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, num_workers=4, batch_size=1, shuffle=True, collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES), transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}")

    model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat)
    model = flatten_sequential(model)
    # One weight per output channel.
    # NOTE(review): 10 entries, so this assumes N_CLASSES == 10 — confirm.
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98], np.float32))

    # `optimizer` (the argument) is a name; `optim` is the optimizer object.
    if optimizer.lower() == "rmsprop":
        optim = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe: partition the flattened model across all visible GPUs,
    # balanced by the memory footprint measured on the first training sample.
    x = first_sample["image"].float().cuda()
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        # Merge the known checkpoint entries into the model's own state and
        # load the merged dict: loading only the filtered checkpoint would
        # fail the strict key check whenever the checkpoint lacks some keys.
        model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
        model.load_state_dict(model_dict)

    b_time = time.time()
    n_fg = N_CLASSES - 1  # number of foreground classes scored in validation
    best_val_dice = [0] * n_fg
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"].cuda()
            y_train = data_dict["label"].cuda().float()
            # Multiplies each sample's loss; presumably 0 for samples whose
            # annotation is incomplete — confirm against ImageLabelDataset.
            flagvec = data_dict["with_complete_groundtruth"]
            optim.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()
            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss.backward()
            optim.step()
            trainloss += loss.item()
            if b_idx % 20 == 0:
                print(f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}")
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice (mean over cases, per foreground class)
            dice_sum = [0] * n_fg
            n_val = [0] * n_fg
            for data_dict in val_dataloader:
                x_val = data_dict["image"].cuda()
                y_val = data_dict["label"]
                with torch.no_grad():
                    o = model(x_val).to(0, non_blocking=True)
                    dice = compute_meandice(o, y_val.to(o), mutually_exclusive=True, include_background=False)
                # NaN scores are skipped (they do not count towards the mean).
                for c, score in zip(range(n_fg), dice[0].tolist()):
                    if not math.isnan(score):
                        dice_sum[c] += score
                        n_val[c] += 1
            # A class that was never scored gets NaN instead of dividing by 0.
            val_dice = [s / n if n else float("nan") for s, n in zip(dice_sum, n_val)]
            print("validation scores " + ", ".join(f"{v:.4f}" for v in val_dice))
            for c in range(1, N_CLASSES):
                if best_val_dice[c - 1] < val_dice[c - 1]:
                    best_val_dice[c - 1] = val_dice[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_dice[c - 1],
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print("best validation scores " + ", ".join(f"{v:.4f}" for v in best_val_dice))

    print("total time", time.time() - b_time)
def _list_split(folder, split, pattern):
    """Return the sorted paths inside ``folder/split`` that match the glob *pattern*."""
    return sorted(glob(os.path.join(folder, split, pattern)))


def _pair_dicts(images, labels):
    """Pair image and label paths positionally into ``{'image': ..., 'label': ...}`` dicts.

    ``zip`` stops at the shorter list, so unmatched trailing files are dropped.
    """
    return [{'image': image_name, 'label': label_name} for image_name, label_name in zip(images, labels)]


def _preprocessing_transforms(resolution):
    """Return a fresh list of the deterministic transforms shared by the train and eval pipelines.

    Loads image/label, adds a channel axis, crops both to the image foreground,
    normalises and then rescales image intensities, and -- only when *resolution*
    is not ``None`` -- resamples both image and label to that voxel spacing.
    """
    transforms = [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        # Optional CT HU filter, currently disabled:
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=['image', 'label'], source_key='image'),  # crop CropForeground
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),  # intensity
    ]
    if resolution is not None:
        # bilinear for the image, nearest for the label so label values stay discrete
        transforms.append(Spacingd(keys=['image', 'label'], pixdim=resolution, mode=('bilinear', 'nearest')))
    return transforms


def _build_transforms(opt):
    """Build the composed ``(train_transforms, eval_transforms)`` pipelines described by *opt*.

    Training adds random flips, affine/elastic deformations and intensity augmentation,
    then pads to ``opt.patch_size`` and takes one random crop of ``opt.patch_size``.
    Evaluation only pads to ``opt.patch_size`` so that whole volumes can be fed to
    sliding-window inference.
    """
    # NOTE(review): the resampled pipeline (opt.resolution set) has always used a Gaussian-noise
    # std bound of 15, while the native-resolution pipeline uses 1. ScaleIntensityd runs just
    # before and (with MONAI defaults) maps the image to [0, 1], so 15 looks like a typo --
    # confirm before unifying. Both values are preserved here.
    noise_std_high = 15 if opt.resolution is not None else 1

    train_transforms = _preprocessing_transforms(opt.resolution) + [
        # spatial augmentation
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        # intensity augmentation
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15),
                            prob=0.1),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # NOTE: np.random.uniform is evaluated once, when this list is built, so the mean/std and
        # offsets arguments below stay fixed for the whole run rather than varying per sample.
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, noise_std_high)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than patch
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label']),
    ]
    eval_transforms = _preprocessing_transforms(opt.resolution) + [
        SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than patch
        ToTensord(keys=['image', 'label']),
    ]
    return Compose(train_transforms), Compose(eval_transforms)


def _evaluate_dice(loader, net, roi_size, post_trans, dice_metric, sw_batch_size=4):
    """Run sliding-window inference over *loader* and return the aggregated Dice.

    Returns ``(metric, images, labels, outputs)``. The last three come from the final
    batch seen, so callers can plot them. *dice_metric* is reset before returning,
    ready for the next evaluation round.
    """
    images = labels = outputs = None
    with torch.no_grad():
        for data in loader:
            images, labels = data["image"].cuda(), data["label"].cuda()
            logits = sliding_window_inference(images, roi_size, sw_batch_size, net)
            outputs = [post_trans(item) for item in decollate_batch(logits)]
            dice_metric(y_pred=outputs, y=labels)
        # aggregate the final mean dice result
        metric = dice_metric.aggregate().item()
    # reset the status for next validation round
    dice_metric.reset()
    return metric, images, labels, outputs


def main():
    """Train a 3D segmentation network ('nnunet' or 'unetr') with MONAI and log to TensorBoard.

    Image/label pairs are read from ``<images_folder>/<split>/image*.nii`` and
    ``<labels_folder>/<split>/label*.nii`` for the train/val/test splits. Training uses
    random patches. After every epoch, Dice is reported on train, val and test, and
    the weights with the best validation Dice are saved to ``best_metric_model.pth``.

    Raises:
        ValueError: if no training image/label pairs are found, or ``opt.network`` is
            neither ``'nnunet'`` nor ``'unetr'``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus
    num_gpus = len(opt.gpu_ids.split(',')) if opt.gpu_ids != '-1' else 0
    print('number of GPU:', num_gpus)

    # Data lists: within each split, images and labels are paired by sorted filename order.
    train_images = _list_split(opt.images_folder, 'train', 'image*.nii')
    train_segs = _list_split(opt.labels_folder, 'train', 'label*.nii')
    # Un-augmented copies, used only to measure Dice on the training set.
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)
    val_images = _list_split(opt.images_folder, 'val', 'image*.nii')
    val_segs = _list_split(opt.labels_folder, 'val', 'label*.nii')
    test_images = _list_split(opt.images_folder, 'test', 'image*.nii')
    test_segs = _list_split(opt.labels_folder, 'test', 'label*.nii')

    # augment the data list for training
    # NOTE(review): each pass doubles the list, so it grows by 2 ** increase_factor_data
    # (not by increase_factor_data) -- confirm this is intended.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    train_dicts = _pair_dicts(train_images, train_segs)
    train_dice_dicts = _pair_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = _pair_dicts(val_images, val_segs)
    test_dicts = _pair_dicts(test_images, test_segs)
    if not train_dicts:
        # Fail fast: an empty loader would otherwise crash with ZeroDivisionError after epoch 1.
        raise ValueError(
            f"no training image/label pairs found under "
            f"{os.path.join(opt.images_folder, 'train')} and {os.path.join(opt.labels_folder, 'train')}"
        )

    train_transforms, val_transforms = _build_transforms(opt)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=opt.batch_size, shuffle=True, collate_fn=list_data_collate,
                              num_workers=opt.workers, pin_memory=False)

    def make_eval_loader(dicts):
        """Build a one-volume-per-batch loader using the deterministic evaluation transforms."""
        dataset = monai.data.Dataset(data=dicts, transform=val_transforms)
        return DataLoader(dataset, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate,
                          pin_memory=False)

    train_dice_loader = make_eval_loader(train_dice_dicts)
    val_loader = make_eval_loader(val_dicts)
    test_loader = make_eval_loader(test_dicts)

    # build the network (compare with ==: `is` on strings tests identity, not equality)
    if opt.network == 'nnunet':
        net = build_net()  # nn build_net
    elif opt.network == 'unetr':
        net = build_UNETR()  # UneTR
    else:
        raise ValueError(f"unknown network {opt.network!r}; expected 'nnunet' or 'unetr'")
    net.cuda()
    if num_gpus > 1:
        net = torch.nn.DataParallel(net)
    if opt.preload is not None:
        # NOTE(review): loaded after the DataParallel wrap, so on multi-GPU runs the checkpoint
        # keys are expected to carry the 'module.' prefix -- confirm for single-GPU checkpoints.
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

    if opt.network == 'nnunet':
        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True)
        # polynomial decay of the learning rate towards 0 over opt.epochs
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)
    else:  # 'unetr' (any other name was rejected above): fixed-lr AdamW, no scheduler
        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)
        net_scheduler = None

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    epoch_len = len(train_ds) // train_loader.batch_size  # nominal steps per epoch, for logging
    writer = SummaryWriter()
    try:
        for epoch in range(opt.epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{opt.epochs}")
            net.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
                optim.zero_grad()
                outputs = net(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optim.step()
                batch_loss = loss.item()
                epoch_loss += batch_loss
                print(f"{step}/{epoch_len}, train_loss: {batch_loss:.4f}")
                writer.add_scalar("train_loss", batch_loss, epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            if net_scheduler is not None:
                update_learning_rate(net_scheduler, optim)

            if (epoch + 1) % val_interval == 0:
                net.eval()
                metric, val_img, val_lbl, val_pred = _evaluate_dice(
                    val_loader, net, opt.patch_size, post_trans, dice_metric)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, _, _, _ = _evaluate_dice(
                    train_dice_loader, net, opt.patch_size, post_trans, dice_metric)
                metric_test, test_img, test_lbl, test_pred = _evaluate_dice(
                    test_loader, net, opt.patch_size, post_trans, dice_metric)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last evaluated batch as a GIF image in TensorBoard with its image and label
                plot_2d_or_3d_image(val_img, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_lbl, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_pred, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_img, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_lbl, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_pred, epoch + 1, writer, index=0, tag="test inference")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # flush and release the event file even if training is interrupted
        writer.close()