def test_correct_results(self, _, keys, mean, std):
    """Compare RandGaussianNoised output with noise rebuilt from the same seed.

    The transform is forced to fire (``prob=1.0``) on ``self.imt`` and the
    expected image is regenerated with NumPy using the identical seed.

    Args:
        _: unused placeholder (case label supplied by the parameterisation).
        keys: keys handed to the transform.
        mean: mean of the Gaussian noise.
        std: upper bound of the uniformly sampled noise standard deviation.
    """
    gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
    gaussian_fn.set_random_state(seed)
    noised = gaussian_fn({"img": self.imt})

    # A private generator produces the same stream that np.random.seed(seed)
    # would, without reseeding the process-wide NumPy RNG for other tests.
    rng = np.random.RandomState(seed)
    # One draw is discarded before sampling the noise
    # (presumably mirrors the transform's own probability draw - confirm).
    rng.random_sample()
    sigma = rng.uniform(0, std)
    expected = self.imt + rng.normal(mean, sigma, size=self.imt.shape)

    np.testing.assert_allclose(expected, noised["img"], atol=1e-5, rtol=1e-5)
def test_numpy_or_torch(keys, mean, std, imt):
    """Check RandGaussianNoised on ``imt`` for every key in ``keys``.

    The same image is stored under each key, the transform is applied with
    ``prob=1.0`` and every output is compared with noise regenerated from the
    identical seed (a fresh sigma/noise pair is drawn per key).

    Args:
        keys: keys to populate and transform.
        mean: mean of the Gaussian noise.
        std: upper bound of the uniformly sampled noise standard deviation.
        imt: input image shared by all keys.
    """
    gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std)
    gaussian_fn.set_random_state(seed)
    noised = gaussian_fn({k: imt for k in keys})

    # Local generator: same stream as np.random.seed(seed), but the global
    # NumPy RNG is left untouched so other tests are not affected.
    rng = np.random.RandomState(seed)
    # One draw is discarded first
    # (presumably mirrors the transform's probability draw - confirm).
    rng.random_sample()

    for k in keys:
        sigma = rng.uniform(0, std)
        expected = imt + rng.normal(mean, sigma, size=imt.shape)
        np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)
def test_correct_results(self, _, im_type, keys, mean, std):
    """Compare RandGaussianNoised output with noise rebuilt from the same seed.

    ``self.imt`` is converted with ``im_type`` and stored under every key; the
    transform runs with ``prob=1.0`` and ``dtype=np.float64``. A single noise
    field is regenerated with NumPy and each key is checked against it.

    Args:
        _: unused placeholder (case label supplied by the parameterisation).
        im_type: callable converting ``self.imt`` to the container under test.
        keys: keys to populate and transform.
        mean: mean of the Gaussian noise.
        std: upper bound of the uniformly sampled noise standard deviation.
    """
    gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64)
    gaussian_fn.set_random_state(seed)
    im = im_type(self.imt)
    noised = gaussian_fn({k: im for k in keys})

    # Local generator: same stream as np.random.seed(seed), without reseeding
    # the process-wide NumPy RNG.
    rng = np.random.RandomState(seed)
    # simulate the randomize() of transform
    rng.random_sample()
    sigma = rng.uniform(0, std)
    noise = rng.normal(mean, sigma, size=self.imt.shape)

    # The same noise field is expected for every key.
    expected = self.imt + noise
    for k in keys:
        actual = noised[k]
        if isinstance(actual, torch.Tensor):
            # assert_allclose needs host memory
            actual = actual.cpu()
        np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5)
def get_xforms(args, mode="train", keys=("image", "label")):
    """Assemble the composed transform chain for training, validation or inference.

    Args:
        args: object exposing ``patch_size`` and ``n_slice`` (read in "train" mode only).
        mode: "train" adds padding, cropping and random augmentation;
            "infer" casts the image only; any other value casts image and label.
        keys: dictionary keys to process; the first is the image, the second
            (used in "train" mode) is the label.

    Returns:
        A ``monai.transforms.Compose`` wrapping the selected transforms.
    """
    image_key = keys[0]
    # one interpolation mode per key, truncated for image-only key tuples
    resample_modes = ("bilinear", "nearest")[:len(keys)]

    pipeline = [
        LoadNiftid(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=resample_modes),
        ScaleIntensityRanged(image_key, a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
    ]

    if mode == "train":
        pipeline += [
            # pad in-plane so that a full patch can always be cropped
            SpatialPadd(keys, spatial_size=(args.patch_size, args.patch_size, -1), mode="reflect"),
            RandAffined(
                keys,
                prob=0.15,
                rotate_range=(-0.05, 0.05),
                scale_range=(-0.1, 0.1),
                mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            RandCropByPosNegLabeld(
                keys,
                label_key=keys[1],
                spatial_size=(args.patch_size, args.patch_size, args.n_slice),
                num_samples=3,
            ),
            RandGaussianNoised(image_key, prob=0.15, std=0.01),
        ]
        # independent random flips along each spatial axis
        pipeline += [RandFlipd(keys, spatial_axis=axis, prob=0.5) for axis in (0, 1, 2)]

    # inference carries no label, so only the image dtype is listed
    cast_dtypes = (np.float32,) if mode == "infer" else (np.float32, np.uint8)
    pipeline += [CastToTyped(keys, dtype=cast_dtypes), ToTensord(keys)]

    return monai.transforms.Compose(pipeline)
def get_xforms_with_synthesis(mode="synthesis", keys=("image", "label"), keys2=("image", "label", "synthetic_lesion")):
    """Assemble the transform chain that includes the lesion-synthesis step.

    Args:
        mode: when equal to "synthesis" the padding, cropping, synthesis and
            augmentation transforms are appended; otherwise only the
            deterministic preprocessing and the final cast are applied.
        keys: keys loaded from disk (image first, label second).
        keys2: keys present after synthesis; entries 0 and 2 receive Gaussian
            noise and all three go through the random affine.

    Returns:
        A ``monai.transforms.Compose`` wrapping the selected transforms.
    """
    image_key = keys[0]
    resample_modes = ("bilinear", "nearest")[: len(keys)]

    steps = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=resample_modes),
        ScaleIntensityRanged(image_key, a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
        CopyItemsd(keys, 1, names=['image_1', 'label_1']),
    ]

    if mode == "synthesis":
        steps += [
            # ensure at least 192x192 in-plane
            SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),
            RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=3),
            TransCustom(
                keys,
                path_synthesis,
                read_cea_aug_slice2,
                pseudo_healthy_with_texture,
                scans_syns,
                decreasing_sequence,
                GEN=15,
                POST_PROCESS=True,
                mask_outer_ring=True,
                new_value=.5,
            ),
            RandAffined(
                keys2,
                prob=0.15,
                # 3 parameters control the transform on 3 dimensions
                rotate_range=(0.05, 0.05, None),
                scale_range=(0.1, 0.1, None),
                mode=("bilinear", "nearest", "bilinear"),
                as_tensor_output=False,
            ),
            RandGaussianNoised((keys2[0], keys2[2]), prob=0.15, std=0.01),
        ]
        steps += [RandFlipd(keys, spatial_axis=axis, prob=0.5) for axis in (0, 1, 2)]
        steps.append(TransCustom2(0.333))

    # only the two loaded keys are cast
    steps.append(CastToTyped(keys, dtype=(np.float32, np.uint8)))
    return monai.transforms.Compose(steps)
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.

    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.

    Raises:
        FileNotFoundError: if ``config_info['training']['model_to_load']`` is set but does not exist.
        Exception: if ``config_info['training']['batch_size_valid']`` is not 1.
    """

    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    training_cfg = config_info['training']
    output_cfg = config_info['output']

    # extract network parameters, perform checks/set defaults if not present and print them to log
    seg_labels = training_cfg.get('seg_labels', [1])
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))

    # the in-plane size is extended with a singleton third dimension
    patch_size = training_cfg["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = training_cfg["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    model_to_load = training_cfg.get('model_to_load')
    if model_to_load is not None:
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        print("Loading model from {}".format(model_to_load))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    seed = training_cfg.get('manual_seed')
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    # ------------------------------------------------------------------
    # Setup data output directory
    # ------------------------------------------------------------------
    out_model_dir = os.path.join(
        output_cfg['out_dir'],
        datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + output_cfg['out_postfix'],
    )
    print("Saving to directory {}\n".format(out_model_dir))

    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in output_cfg:
        out_cache_dir = output_cfg['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    # Read the input files for training and validation
    print("*** Loading input data for training...")
    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # ------------------------------------------------------------------
    # Load data
    # ------------------------------------------------------------------
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=training_cfg['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if training_cfg['batch_size_valid'] != 1:
        raise Exception("Batch sizes different from 1 at validation are currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        # stop once no axis can be downsampled any further
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    # first level keeps the input resolution; last level uses a 3x3 kernel
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler (polynomial decay, power 0.9)
    opt = torch.optim.SGD(net.parameters(), lr=float(training_cfg['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / training_cfg['nr_train_epochs']) ** 0.9
    )

    # ------------------------------------------------------------------
    # MONAI evaluator
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet evaluator engine...\n")
    # NOTE: the lambdas below read `trainer`, which is bound further down;
    # they are only invoked once training is running.
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if output_cfg['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            """Run one validation step, averaging predictions over flipped copies of the input."""
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # flip-based augmentation at inference: each flipped prediction
                # is flipped back before averaging
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # MONAI trainer
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet trainer engine...\n")

    validation_every_n_epochs = training_cfg['validation_every_n_epochs']
    # Number of iterations actually run per epoch. len(train_loader) also counts
    # the final partial batch, so validation stays aligned with epoch boundaries
    # and the interval is never 0 for datasets smaller than one batch.
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=output_cfg['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            """Run one optimisation step with the deep-supervision loss."""
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # the label is resized to each auxiliary output; the i-th
                # output contributes with weight 0.5 ** i
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, tgt) for i, (p, tgt) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=training_cfg['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # Run training
    # ------------------------------------------------------------------
    print("*** Run training...")
    trainer.run()
    print("Done!")
keys=["image", "label"], sigma_range=(0, 1), magnitude_range=(0, 1), spatial_size=None, prob=0.5, rotate_range=(0, -math.pi / 36, math.pi / 36, 0), # -15, 15 / -5, 5 shear_range=None, translate_range=None, scale_range=None, mode=("bilinear", "nearest"), padding_mode="zeros", # as_tensor_output=False ), RandGaussianNoised(keys=["image"], prob=0.5, mean=0.0, std=0.1 # allow_missing_keys=False ), #RandScaleIntensityd( # keys=["image"], # factors=0.05, # this is 10%, try 5% # prob=0.5 #), #RandGaussianSmoothd( # keys=["image"], # sigma_x=(0.25, 1.5), # sigma_y=(0.25, 1.5), # sigma_z=(0.25, 1.5), # prob=0.5, # approx='erf' # allow_missing_keys=False
def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
    """Segment every ``*_ct.nii.gz`` volume in ``data_folder`` and save the results.

    The lexicographically last ``*.pt`` file in ``model_folder`` is loaded.
    For each volume the prediction on the original image is averaged with
    predictions on four noisy copies and their flips along dims 2 and 3.
    The argmax label map is written to ``prediction_folder`` and then copied,
    under a shortened name, into ``<prediction_folder>/to_submit``.

    Args:
        data_folder: folder containing the input ``*_ct.nii.gz`` images.
        model_folder: folder containing the ``*.pt`` checkpoints.
        prediction_folder: folder receiving the predicted segmentations.

    Raises:
        FileNotFoundError: if ``model_folder`` holds no ``*.pt`` checkpoint.
    """
    ckpts = sorted(glob.glob(os.path.join(model_folder, "*.pt")))
    if not ckpts:
        # fail with an actionable message instead of an IndexError on ckpts[-1]
        raise FileNotFoundError(f"no checkpoint (*.pt) found in {model_folder}")
    ckpt = ckpts[-1]
    for x in ckpts:
        logging.info(f"available model file: {x}.")
    logging.info("----")
    logging.info(f"using {ckpt}.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)
    net.load_state_dict(torch.load(ckpt, map_location=device))
    net.eval()

    image_folder = os.path.abspath(data_folder)
    images = sorted(glob.glob(os.path.join(image_folder, "*_ct.nii.gz")))
    logging.info(f"infer: image ({len(images)}) folder: {data_folder}")
    infer_files = [{"image": img} for img in images]

    keys = ("image", )
    image_key = keys[0]
    infer_transforms = get_xforms("infer", keys)
    infer_ds = monai.data.Dataset(data=infer_files, transform=infer_transforms)
    infer_loader = monai.data.DataLoader(
        infer_ds,
        batch_size=1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    inferer = get_inferer()
    saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest")
    with torch.no_grad():
        for infer_data in infer_loader:
            logging.info(f"segmenting {infer_data['image_meta_dict']['filename_or_obj']}")
            preds = inferer(infer_data[image_key].to(device), net)
            n = 1.0
            for _ in range(4):
                # test time augmentations: a fresh noisy copy per round,
                # moved to the device once and reused for the flips
                noisy = RandGaussianNoised(image_key, prob=1.0, std=0.01)(infer_data)[image_key].to(device)
                preds = preds + inferer(noisy, net)
                n = n + 1.0
                for dims in [[2], [3]]:
                    flip_pred = inferer(torch.flip(noisy, dims=dims), net)
                    # undo the flip so the prediction lines up with the input
                    preds = preds + torch.flip(flip_pred, dims=dims)
                    n = n + 1.0
            preds = preds / n
            preds = (preds.argmax(dim=1, keepdims=True)).float()
            saver.save_batch(preds, infer_data["image_meta_dict"])

    # copy the saved segmentations into the required folder structure for submission
    submission_dir = os.path.join(prediction_folder, "to_submit")
    os.makedirs(submission_dir, exist_ok=True)  # no exists()/makedirs() race
    files = glob.glob(os.path.join(prediction_folder, "volume*", "*.nii.gz"))
    for f in files:
        # strip the fixed-length prefix and the "_ct_seg" suffix by position
        new_name = os.path.basename(f)
        new_name = new_name[len("volume-covid19-A-0"):]
        new_name = new_name[:-len("_ct_seg.nii.gz")] + ".nii.gz"
        to_name = os.path.join(submission_dir, new_name)
        shutil.copy(f, to_name)
    logging.info(f"predictions copied to {submission_dir}.")
def main():
    """Train a segmentation network and track Dice on train/val/test splits.

    Options are parsed from the command line. Image/label pairs are globbed
    from ``<images_folder>/{train,val,test}`` and ``<labels_folder>/...``,
    the network is trained with Dice + cross-entropy loss, and after every
    epoch the mean Dice of the three splits is logged to TensorBoard. The
    weights with the best validation Dice are saved to
    ``best_metric_model.pth``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # un-augmented copy of the training list, used only to measure Dice
    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    # NOTE(review): each pass doubles the list, so it grows by 2 ** increase_factor_data
    # rather than by the factor itself - confirm this is intended.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader
    def _pair(image_paths, label_paths):
        """Zip image and label paths into the dict records the datasets expect."""
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(image_paths, label_paths)]

    train_dicts = _pair(train_images, train_segs)
    train_dice_dicts = _pair(train_images_for_dice, train_segs_for_dice)
    val_dicts = _pair(val_images, val_segs)
    test_dicts = _pair(test_images, test_segs)

    # Transforms list
    # Need to concatenate multiple channels here if you want multichannel segmentation
    # Check other examples on Monai webpage.
    def _preprocessing():
        """Return fresh load/normalise steps, plus resampling when a resolution is given."""
        steps = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if opt.resolution is not None:
            steps.append(Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')))
        return steps

    # NOTE: the np.random.uniform(...) arguments below are evaluated once, when the
    # pipeline is built, so noise mean/std and the intensity offset stay fixed for the run.
    augmentation = [
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
    ]

    train_transforms = Compose(_preprocessing() + augmentation + [ToTensord(keys=['image', 'label'])])
    val_transforms = Compose(_preprocessing() + [ToTensord(keys=['image', 'label'])])

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers,
                              pin_memory=torch.cuda.is_available())

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers,
                                   pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers,
                            pin_memory=torch.cuda.is_available())

    # create a test data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers,
                             pin_memory=torch.cuda.is_available())

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    # sigmoid activation followed by thresholding: outputs are already binary
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    # loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)
    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()

    # Iterations actually run per epoch (includes the final partial batch), so the
    # TensorBoard global step never overlaps between consecutive epochs.
    epoch_len = len(train_loader)

    def plot_dice(images_loader):
        """Return (mean Dice, last images, last labels, last binarised outputs) for a loader."""
        metric_sum = 0.0
        metric_count = 0
        val_images = None
        val_labels = None
        val_outputs = None
        for data in images_loader:
            val_images, val_labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
        metric = metric_sum / metric_count
        metric_values.append(metric)
        return metric, val_images, val_labels, val_outputs

    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)

                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # val_outputs was already activated and thresholded by post_trans; a second
                # sigmoid + 0.5 threshold would map both 0 and 1 to 1 and log an all-ones mask.
                val_outputs = val_outputs.float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_samples):
    """Compose the transform chain for one task in the given pipeline mode.

    Args:
        mode: ``"train"`` adds random augmentation, ``"validation"`` only
            casts image and label; any other value is handled as the
            image-only inference branch. Note that only ``"test"`` restricts
            loading to the image key.
        task_id: index into the module-level ``clip_values``, ``spacing``,
            ``normalize_values`` and ``patch_size`` tables.
        pos_sample_num: forwarded as ``pos`` to ``RandCropByPosNegLabeld``
            (train mode only).
        neg_sample_num: forwarded as ``neg`` to ``RandCropByPosNegLabeld``
            (train mode only).
        num_samples: forwarded as ``num_samples`` to ``RandCropByPosNegLabeld``
            (train mode only).

    Returns:
        A ``Compose`` of loading, preprocessing and mode-specific transforms.
    """
    image_only = ["image"]
    image_and_label = ["image", "label"]
    keys = ["image"] if mode == "test" else ["image", "label"]

    # 1. loading, followed by 2. sampling / anisotropic preprocessing
    pipeline = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]

    # 3. spatial / intensity transforms and final type casting
    if mode == "train":
        crop_size = patch_size[task_id]
        blur_sigma = (0.5, 1.15)
        pipeline += [
            SpatialPadd(keys=image_and_label, spatial_size=crop_size),
            RandCropByPosNegLabeld(
                keys=image_and_label,
                label_key="label",
                spatial_size=crop_size,
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=image_and_label,
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=image_only, std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=image_only,
                sigma_x=blur_sigma,
                sigma_y=blur_sigma,
                sigma_z=blur_sigma,
                prob=0.15,
            ),
            RandScaleIntensityd(keys=image_only, factors=0.3, prob=0.15),
        ]
        # one independent random flip per spatial axis, in axis order 0, 1, 2
        for axis in range(3):
            pipeline.append(RandFlipd(image_and_label, spatial_axis=[axis], prob=0.5))
        pipeline += [
            CastToTyped(keys=image_and_label, dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=image_and_label),
        ]
    elif mode == "validation":
        pipeline += [
            CastToTyped(keys=image_and_label, dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=image_and_label),
        ]
    else:
        pipeline += [
            CastToTyped(keys=image_only, dtype=np.float32),
            EnsureTyped(keys=image_only),
        ]

    return Compose(pipeline)
def test_train_timing(self):
    """Train a small 3D UNet with the cached, AMP-enabled pipeline and check Dice.

    Pairs ``img*.nii.gz`` / ``seg*.nii.gz`` files found in ``self.data_dir``,
    trains on the first 32 pairs and validates on the last 9 every epoch on
    ``cuda:0``, printing per-step and per-epoch timings. The test passes when
    the best validation mean Dice is greater than 0.95.
    """
    pair_keys = ["image", "label"]
    interp_modes = ("bilinear", "nearest")

    image_paths = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [
        dict(image=image_path, label=label_path)
        for image_path, label_path in zip(image_paths[:32], label_paths[:32])
    ]
    val_files = [
        dict(image=image_path, label=label_path)
        for image_path, label_path in zip(image_paths[-9:], label_paths[-9:])
    ]

    device = torch.device("cuda:0")

    def preprocessing():
        # deterministic steps common to both pipelines; fresh instances per call
        return [
            LoadImaged(keys=pair_keys),
            EnsureChannelFirstd(keys=pair_keys),
            Spacingd(keys=pair_keys, pixdim=(1.0, 1.0, 1.0), mode=interp_modes),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=pair_keys, source_key="image"),
        ]

    # define transforms for train and validation
    train_steps = preprocessing()
    train_steps += [
        # pre-compute foreground and background indexes
        # and cache them to accelerate training
        FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
        # change to execute transforms with Tensor data
        EnsureTyped(keys=pair_keys),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=pair_keys, device=device),
        # randomly crop out patch samples from big image based on pos / neg ratio;
        # the image centers of negative samples must be in valid image area
        RandCropByPosNegLabeld(
            keys=pair_keys,
            label_key="label",
            spatial_size=(64, 64, 64),
            pos=1,
            neg=1,
            num_samples=4,
            fg_indices_key="label_fg",
            bg_indices_key="label_bg",
        ),
        RandFlipd(keys=pair_keys, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=pair_keys, prob=0.5),
        RandRotate90d(keys=pair_keys, prob=0.5, spatial_axes=(1, 2)),
        RandZoomd(keys=pair_keys, prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
        RandRotated(
            keys=pair_keys,
            prob=0.5,
            range_x=np.pi / 4,
            mode=interp_modes,
            align_corners=True,
            dtype=np.float64,
        ),
        RandAffined(keys=pair_keys, prob=0.5, rotate_range=np.pi / 2, mode=interp_modes),
        RandGaussianNoised(keys="image", prob=0.5),
        RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
    ]
    train_transforms = Compose(train_steps)

    val_steps = preprocessing()
    val_steps += [
        EnsureTyped(keys=pair_keys),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=pair_keys, device=device),
    ]
    val_transforms = Compose(val_steps)

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch
    roi_size = (96, 96, 96)  # sliding-window size used during validation
    sw_batch_size = 4

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    unet = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    model = unet.to(device)

    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # number of steps per epoch, only used for progress printing
    epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        step_start = time.time()
        for step, batch_data in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                  f" step time: {(time.time() - step_start):.4f}")
            # restart the step timer so the next measurement includes batch fetching
            step_start = time.time()
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        logits = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model)
                    predictions = list(map(post_pred, decollate_batch(logits)))
                    targets = list(map(post_label, decollate_batch(val_data["label"])))
                    dice_metric(y_pred=predictions, y=targets)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                best_metric = max(best_metric, metric)
                print(
                    f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                )
        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )

    total_time = time.time() - total_start
    print(
        f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
    )
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
def main():
    """Train an ensemble of DynUNet models and evaluate them jointly.

    Options come from ``Options().parse()``. Image/label files are globbed from
    ``opt.images_folder`` / ``opt.labels_folder``. For ensemble member ``i`` the
    slice ``[i * split_val, (i + 1) * split_val)`` is used for validation and
    the remaining files up to ``len(images) - split_val`` for training; the last
    ``opt.split_test`` files form the test set. One network is trained per
    member and the trained networks are evaluated on the test set twice, once
    with ``MeanEnsembled`` and once with ``VoteEnsembled``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    set_determinism(seed=0)
    device = torch.device(opt.gpu_id)

    # ------- Data loader creation ----------

    def pair_up(image_paths, label_paths):
        # one {"image", "label"} record per aligned pair of paths
        return [{"image": img, "label": seg} for img, seg in zip(image_paths, label_paths)]

    # images
    images = sorted(glob(os.path.join(opt.images_folder, 'image*.nii')))
    segs = sorted(glob(os.path.join(opt.labels_folder, 'label*.nii')))

    # NOTE(review): training stops at len(images) - split_val while the test set
    # starts at len(images) - split_test; the two only coincide when
    # split_val == split_test -- confirm this is intended.
    train_stop = len(images) - opt.split_val
    train_files = []
    val_files = []
    for i in range(opt.models_ensemble):
        fold_start = opt.split_val * i
        fold_stop = opt.split_val * (i + 1)
        train_files.append(pair_up(
            images[:fold_start] + images[fold_stop:train_stop],
            segs[:fold_start] + segs[fold_stop:train_stop],
        ))
        val_files.append(pair_up(images[fold_start:fold_stop], segs[fold_start:fold_stop]))

    test_start = len(images) - opt.split_test
    test_files = pair_up(images[test_start:len(images)], segs[test_start:len(images)])

    # ----------- Transforms list --------------
    both = ['image', 'label']
    interp = ('bilinear', 'nearest')

    def deterministic_head():
        # loading + intensity normalisation, plus resampling when a target
        # resolution is requested; fresh transform instances on every call
        head = [
            LoadImaged(keys=both),
            AddChanneld(keys=both),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if opt.resolution is not None:
            head.append(Spacingd(keys=both, pixdim=opt.resolution, mode=interp))
        return head

    train_transforms = deterministic_head() + [
        RandFlipd(keys=both, prob=0.1, spatial_axis=1),
        RandFlipd(keys=both, prob=0.1, spatial_axis=0),
        RandFlipd(keys=both, prob=0.1, spatial_axis=2),
        RandAffined(keys=both, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=both, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=both, mode=interp, prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=both, mode=interp, prob=0.1, sigma_range=(5, 8),
                       magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # NOTE: mean/std/offsets are drawn once here, when the transform is
        # built, and then stay fixed for the whole run.
        RandGaussianNoised(keys=['image'], prob=0.1,
                           mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=both, roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=both),
    ]
    val_transforms = deterministic_head() + [ToTensord(keys=both)]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # ---------- Creation of DataLoaders -------------
    pin = torch.cuda.is_available()
    train_dss = [
        CacheDataset(data=train_files[i], transform=train_transforms)
        for i in range(opt.models_ensemble)
    ]
    train_loaders = [
        DataLoader(train_dss[i], batch_size=opt.batch_size, shuffle=True,
                   num_workers=opt.workers, pin_memory=pin)
        for i in range(opt.models_ensemble)
    ]

    val_dss = [
        CacheDataset(data=val_files[i], transform=val_transforms)
        for i in range(opt.models_ensemble)
    ]
    val_loaders = [
        DataLoader(val_dss[i], batch_size=1, num_workers=opt.workers, pin_memory=pin)
        for i in range(opt.models_ensemble)
    ]

    test_ds = CacheDataset(data=test_files, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin)

    def train(index):
        """Build, train and return the network of ensemble member ``index``."""
        # ---------- Build the nn-Unet network ------------
        sizes = opt.patch_size
        spacings = opt.spacing if opt.resolution is None else opt.resolution

        # derive per-level strides/kernels from patch size and voxel spacing:
        # an axis is down-sampled while it is at least 8 voxels long and its
        # spacing is within 2x of the finest axis
        strides, kernels = [], []
        while True:
            spacing_ratio = [sp / min(spacings) for sp in spacings]
            stride = [
                2 if ratio <= 2 and size >= 8 else 1
                for (ratio, size) in zip(spacing_ratio, sizes)
            ]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            sizes = [i / j for i, j in zip(sizes, stride)]
            spacings = [i * j for i, j in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])

        net = monai.networks.nets.DynUNet(
            spatial_dims=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            res_block=True,
            # act=act_type,
            # norm=Norm.BATCH,
        ).to(device)

        from torchsummaryX import summary

        # Dummy batch for the architecture summary. It is sampled on the CPU (so
        # the seeded CPU RNG stream is consumed as before) and then moved to the
        # same device as the network; the previous `.cuda()` placed it on the
        # current CUDA device, which need not be `device`.
        data = torch.randn(
            int(opt.batch_size), int(opt.in_channels),
            int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2]),
        ).to(device)

        out = net(data)
        summary(net, data)
        print("out size: {}".format(out.size()))

        # if opt.preload is not None:
        #     net.load_state_dict(torch.load(opt.preload))

        # ---------- ------------------------ ------------

        optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
        # polynomial decay of the learning rate over opt.epochs
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)

        loss_function = monai.losses.DiceCELoss(sigmoid=True)

        val_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
        ])

        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
        ]

        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loaders[index],
            network=net,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                         sw_batch_size=opt.batch_size, overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice": MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
            val_handlers=val_handlers)

        train_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ])

        train_handlers = [
            ValidationHandler(validator=evaluator, interval=5, epoch_level=True),
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": optim},
                            save_final=True, epoch_level=True),
        ]

        trainer = SupervisedTrainer(
            device=device,
            max_epochs=opt.epochs,
            train_data_loader=train_loaders[index],
            network=net,
            optimizer=optim,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=train_post_transforms,
            amp=False,
            train_handlers=train_handlers,
        )
        trainer.run()
        return net

    models = [train(i) for i in range(opt.models_ensemble)]

    # -------- Test the models ---------

    def ensemble_evaluate(post_transforms, models):
        """Run all ``models`` on the test loader and score the combined prediction."""
        evaluator = EnsembleEvaluator(
            device=device,
            val_data_loader=test_loader,
            pred_keys=opt.pred_keys,
            networks=models,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                         sw_batch_size=opt.batch_size, overlap=0.5),
            post_transform=post_transforms,
            key_val_metric={
                "test_mean_dice": MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
        )
        evaluator.run()

    mean_post_transforms = Compose([
        MeanEnsembled(
            keys=opt.pred_keys,
            output_key="pred",
            # in this particular example, we use validation metrics as weights
            weights=opt.weights_models,
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
    ])

    print('Results from MeanEnsembled:')
    ensemble_evaluate(mean_post_transforms, models)

    vote_post_transforms = Compose([
        Activationsd(keys=opt.pred_keys, sigmoid=True),
        # transform data into discrete before voting
        AsDiscreted(keys=opt.pred_keys, threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        VoteEnsembled(keys=opt.pred_keys, output_key="pred"),
    ])

    print('Results from VoteEnsembled:')
    ensemble_evaluate(vote_post_transforms, models)
def main():
    """Run the distributed DiNTS architecture search for one cross-validation fold.

    Command-line arguments select the data list (``--json`` / ``--json_key``),
    the fold (``--fold`` of ``--num_folds``), the data and output roots, an
    optional checkpoint and the local process rank. The training half of the
    fold is shuffled and split in two: ``train_files_w`` drives the network
    weight updates and ``train_files_a`` the architecture (``log_alpha_a`` /
    ``log_alpha_c``) updates, which start after ``num_epochs_warmup`` epochs.
    Every ``num_epochs_per_validation`` epochs the model is validated with
    sliding-window inference; rank 0 writes the decoded search code, a
    ``progress.yaml`` summary and a row of ``accuracy_history.csv`` under
    ``--output_root``.
    """
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    # `required=int` was a typo for `type=int`: it only worked because the class
    # object is truthy. The option stays mandatory and is now parsed as an int.
    parser.add_argument(
        "--local_rank",
        type=int,
        required=True,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iteration
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    entries = json_data[args.json_key]
    split = len(entries) // num_folds
    list_train = entries[:(split * fold)] + entries[(split * (fold + 1)):]
    list_valid = entries[(split * fold):(split * (fold + 1))]

    def existing_pairs(items):
        # resolve paths against the data root and silently drop entries whose
        # image or label file is missing on disk
        pairs = []
        for item in items:
            str_img = os.path.join(args.root, item["image"])
            str_seg = os.path.join(args.root, item["label"])
            if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
                continue
            pairs.append({"image": str_img, "label": str_seg})
        return pairs

    # training data
    train_files = existing_pairs(list_train)

    random.shuffle(train_files)

    # first half of the shuffled list updates the weights ...
    train_files_w = train_files[:len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))

    # ... and the second half updates the architecture parameters
    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    val_files = existing_pairs(list_valid)
    val_files = partition_dataset(data=val_files,
                                  shuffle=False,
                                  num_partitions=world_size,
                                  even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    def dilate_each_class(x):
        # one binary mask per class, each dilated by 48 iterations, stacked
        # along axis 0; used only to guide the class-balanced cropping
        return np.concatenate(tuple([
            ndimage.binary_dilation(
                (x == _k).astype(x.dtype), iterations=48).astype(x.dtype)
            for _k in range(output_classes)
        ]), axis=0)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        Lambdad(
            keys=["label4crop"],
            func=dilate_each_class,
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        SpatialPadd(keys=["image", "label", "label4crop"],
                    spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"],
                                label_key="label4crop",
                                num_classes=output_classes,
                                ratios=[1] * output_classes,
                                spatial_size=patch_size,
                                num_samples=num_patches_per_image),
        # the crop guide is no longer needed; replace it with a scalar
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"],
                    range_x=0.3,
                    range_y=0.3,
                    range_z=0.3,
                    mode=["bilinear", "nearest"],
                    prob=0.2),
        RandZoomd(keys=["image", "label"],
                  min_zoom=0.8,
                  max_zoom=1.2,
                  mode=["trilinear", "nearest"],
                  align_corners=[True, None],
                  prob=0.16),
        RandGaussianSmoothd(keys=["image"],
                            sigma_x=(0.5, 1.15),
                            sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15),
                            prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"],
                    dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"])
    ])

    train_ds_a = monai.data.CacheDataset(data=train_files_a,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0,
                                     num_workers=2)

    # monai.data.Dataset can be used as alternatives when debugging or RAM space is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=1,
                                  shuffle=False)

    # DataLoader can be used as alternatives when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose(
        [EnsureType(),
         AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    def segmentation_loss(logits, target):
        # with exactly two classes the channel order and the labels are flipped
        # before the loss is evaluated
        if output_classes == 2:
            return loss_func(torch.flip(logits, dims=[1]), 1 - target)
        return loss_func(logits, target)

    # optimizer
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                lr=learning_rate * world_size,
                                momentum=0.9,
                                weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")

        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        find_unused_parameters=True)

    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(
            args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # `weight_parameters()` lives on the DiNTS module itself, so unwrap the DDP
    # container whenever the model was actually wrapped above (previously this
    # was keyed on world_size, which can disagree with the wrapping condition).
    search_net = model.module if isinstance(model, DistributedDataParallel) else model

    def train_weights(flag):
        # flag=True: network weights trainable, architecture parameters frozen;
        # flag=False: the opposite, for the architecture update step
        for param in search_net.weight_parameters():
            param.requires_grad = flag
        dints_space.log_alpha_a.requires_grad = not flag
        dints_space.log_alpha_c.requires_grad = not flag

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    best_metric_iterations = -1
    idx_iter = 0

    if dist.get_rank() == 0:
        writer = SummaryWriter(
            log_dir=os.path.join(args.output_root, "Events"))

        with open(os.path.join(args.output_root, "accuracy_history.csv"),
                  "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        # halve the learning rate each time the post-warm-up progress fraction
        # passes one of the milestones
        decay = 0.5**np.sum([
            (epoch - num_epochs_warmup) /
            (num_epochs - num_epochs_warmup) > learning_rate_milestones
        ])
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        # [sum of losses, number of steps] so that the mean survives all_reduce
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_len = len(train_loader_w)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(
                device), batch_data["label"].to(device)

            # ---- weight update step ----
            train_weights(True)

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    loss = segmentation_loss(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                loss = segmentation_loss(outputs, labels)
                loss.backward()
                optimizer.step()

            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) +
                      f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)

            # no architecture updates during warm-up
            if epoch < num_epochs_warmup:
                continue

            # ---- architecture update step ----
            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                # the architecture loader is cycled independently of the epoch
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(
                device), sample_a["label"].to(device)
            train_weights(False)

            # linear increase topology and RAM loss
            probs_a, _ = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1) *
                                F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape,
                                                           full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost -
                                      ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            # grows linearly from 0 to 1 over the post-warm-up epochs
            combination_weights = (epoch - num_epochs_warmup) / (
                num_epochs - num_epochs_warmup)

            # NOTE(review): the two branches below weight the regularisers
            # differently (combination_weights scales all three terms with AMP,
            # only the entropy terms without). `amp` is hard-coded to True, so
            # only the first form is exercised -- confirm which is intended.
            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    loss = segmentation_loss(outputs_search, labels_search)

                    loss += combination_weights * (
                        (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss
                        + 0.001 * topology_loss)

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                loss = segmentation_loss(outputs_search, labels_search)

                loss += 1.0 * (
                    combination_weights * (entropy_alpha_a + entropy_alpha_c)
                    + ram_cost_loss + 0.001 * topology_loss)

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19]) +
                    f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(),
                                  epoch_len * epoch + step)

        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        # the architecture loss is not reduced: rank 0 reports its local mean
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                # per foreground class: [sum of valid dice, number of valid cases]
                metric = torch.zeros((output_classes - 1) * 2,
                                     dtype=torch.float,
                                     device=device)

                for _index, val_data in enumerate(val_loader, start=1):
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    # autocast(enabled=False) is a no-op, so one call covers
                    # both the AMP and the full-precision configuration
                    with torch.cuda.amp.autocast(enabled=amp):
                        pred = sliding_window_inference(
                            val_images,
                            patch_size_valid,
                            num_sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )

                    # post-process the single item of the batch, then restore
                    # the leading batch axis
                    val_outputs = post_pred(pred[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)

                    print(_index, "/", len(val_loader), value)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        # validity mask of THIS class (was value[0, 0] for every
                        # class, so a NaN in one class corrupted the others)
                        val1 = 1.0 - torch.isnan(value[0, _c]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                # synchronizes all processes and reduce results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print(
                            "evaluation metric - class {0:d}:".format(_c + 1),
                            metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    # the decoded search code is stored at every validation;
                    # "best_path" records the iteration of the best one so far
                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root,
                                     "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(
                        best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"),
                              "w") as out_file:
                        yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, avg_metric, best_metric,
                                best_metric_epoch))

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(
                            os.path.join(args.output_root,
                                         "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n"
                            .format(epoch + 1, avg_metric, loss_torch_epoch,
                                    lr, elapsed_time, idx_iter))

                # keep the other ranks waiting until rank 0 finished writing
                dist.barrier()

            torch.cuda.empty_cache()

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()
def _build_transform_lists(opt):
    """Build the (training, validation) transform lists selected by ``opt``.

    Both pipelines share the same loading / cropping / intensity
    preprocessing.  A ``Spacingd`` resampling step is appended to that
    preprocessing only when ``opt.resolution`` is not ``None``.

    Args:
        opt: parsed options; ``opt.resolution`` and ``opt.patch_size`` are read.

    Returns:
        A ``(train_steps, val_steps)`` tuple of plain lists of transforms
        (not yet wrapped in ``Compose``).
    """
    both = ['image', 'label']
    image_only = ['image']
    resample = opt.resolution is not None

    def _preprocessing():
        # Fresh transform instances on every call so that the training and
        # validation pipelines never share objects.
        steps = [
            LoadImaged(keys=both),
            AddChanneld(keys=both),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # CT HU filter
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=both, source_key='image'),
            NormalizeIntensityd(keys=image_only),
            ScaleIntensityd(keys=image_only),
        ]
        if resample:
            steps.append(Spacingd(keys=both, pixdim=opt.resolution, mode=('bilinear', 'nearest')))
        return steps

    # The upper bound of the noise std differs between the two configurations
    # (15 when resampling, 1 otherwise).
    # NOTE(review): the reason for the different bounds is not visible here.
    noise_std_max = 15 if resample else 1

    # NOTE: the np.random.uniform(...) arguments below are evaluated once,
    # while the pipeline is being built, so the drawn mean / std / offset
    # stay fixed for the lifetime of the returned transforms.
    train_steps = _preprocessing()
    train_steps += [
        RandFlipd(keys=both, prob=0.15, spatial_axis=1),
        RandFlipd(keys=both, prob=0.15, spatial_axis=0),
        RandFlipd(keys=both, prob=0.15, spatial_axis=2),
        RandAffined(keys=both, mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=both, mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=both, mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=both, mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200),
                       scale_range=(0.15, 0.15, 0.15), padding_mode="zeros"),
        RandGaussianSmoothd(keys=image_only, sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15), prob=0.1),
        RandAdjustContrastd(keys=image_only, gamma=(0.5, 2.5), prob=0.1),
        RandGaussianNoised(keys=image_only, prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, noise_std_max)),
        RandShiftIntensityd(keys=image_only, offsets=np.random.uniform(0, 0.3), prob=0.1),
        # pad if the image is smaller than patch
        SpatialPadd(keys=both, spatial_size=opt.patch_size, method='end'),
        RandSpatialCropd(keys=both, roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=both),
    ]

    val_steps = _preprocessing()
    val_steps += [
        # pad if the image is smaller than patch
        SpatialPadd(keys=both, spatial_size=opt.patch_size, method='end'),
        ToTensord(keys=both),
    ]
    return train_steps, val_steps


def _mean_dice(images_loader, net, roi_size, post_trans, dice_metric):
    """Run sliding-window inference over a loader and aggregate the Dice metric.

    The caller is expected to wrap the call in ``torch.no_grad()``.

    Args:
        images_loader: iterable of batches holding ``"image"`` and ``"label"``.
        net: network passed to ``sliding_window_inference`` as the predictor.
        roi_size: window size handed to ``sliding_window_inference``.
        post_trans: callable applied to every decollated prediction.
        dice_metric: metric object; it is fed each batch, aggregated, then reset.

    Returns:
        ``(metric, images, labels, outputs)`` where ``metric`` is the aggregated
        value as a Python number and the remaining three entries come from the
        LAST batch of the loader (``None`` if no batch was produced).
    """
    last_images = None
    last_labels = None
    last_outputs = None
    sw_batch_size = 4
    for data in images_loader:
        last_images, last_labels = data["image"].cuda(), data["label"].cuda()
        raw_outputs = sliding_window_inference(last_images, roi_size, sw_batch_size, net)
        last_outputs = [post_trans(i) for i in decollate_batch(raw_outputs)]
        dice_metric(y_pred=last_outputs, y=last_labels)

    # aggregate the final mean dice result
    metric = dice_metric.aggregate().item()
    # reset the status for next validation round
    dice_metric.reset()
    return metric, last_images, last_labels, last_outputs


def main():
    """Train a segmentation network and evaluate it after every epoch.

    Reads all settings from ``Options().parse()``, builds train / train-dice /
    validation / test loaders from ``image*.nii`` / ``label*.nii`` files, trains
    the network chosen by ``opt.network`` (``'nnunet'`` or ``'unetr'``) with
    ``DiceCELoss`` and, after every epoch, computes the Dice metric on the
    validation, training and test sets.  The weights with the best validation
    Dice are written to ``best_metric_model.pth`` and scalars / images are
    logged through a ``SummaryWriter``.

    Raises:
        ValueError: if ``opt.network`` is neither ``'nnunet'`` nor ``'unetr'``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus
    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # un-augmented copies, used to measure the Dice score on the training set
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    # NOTE: every pass extends the list with itself, so the list length is
    # multiplied by 2 ** increase_factor_data (not by increase_factor_data).
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader
    train_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images, train_segs)]
    train_dice_dicts = [{'image': image_name, 'label': label_name}
                        for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)]
    val_dicts = [{'image': image_name, 'label': label_name}
                 for image_name, label_name in zip(val_images, val_segs)]
    test_dicts = [{'image': image_name, 'label': label_name}
                  for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list
    train_steps, val_steps = _build_transform_lists(opt)
    train_transforms = Compose(train_steps)
    val_transforms = Compose(val_steps)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=opt.batch_size, shuffle=True,
                              collate_fn=list_data_collate, num_workers=opt.workers,
                              pin_memory=False)

    # create a training_dice data loader (training images, validation transforms)
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1, num_workers=opt.workers,
                                   collate_fn=list_data_collate, pin_memory=False)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers,
                            collate_fn=list_data_collate, pin_memory=False)

    # create a test data loader
    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers,
                             collate_fn=list_data_collate, pin_memory=False)

    # build the network
    # Strings are compared with ``==``: ``is`` tests object identity and only
    # works by accident when both strings happen to be interned.
    if opt.network == 'nnunet':
        net = build_net()  # nn build_net
    elif opt.network == 'unetr':
        net = build_UNETR()  # UneTR
    else:
        raise ValueError(
            f"unsupported network {opt.network!r}; expected 'nnunet' or 'unetr'")
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        # NOTE(review): torch.load unpickles the file; only point opt.preload
        # at checkpoints from a trusted source.
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

    # opt.network was validated above, so exactly one branch is taken here.
    net_scheduler = None
    if opt.network == 'nnunet':
        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99,
                                weight_decay=3e-5, nesterov=True)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)
    else:
        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    writer = SummaryWriter()

    # number of optimisation steps per epoch; constant for the whole run
    epoch_len = len(train_ds) // train_loader.batch_size

    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # only the nnunet configuration owns a learning-rate scheduler
        if net_scheduler is not None:
            update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = _mean_dice(
                    val_loader, net, opt.patch_size, post_trans, dice_metric)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = _mean_dice(
                    train_dice_loader, net, opt.patch_size, post_trans, dice_metric)
                metric_test, test_images, test_labels, test_outputs = _mean_dice(
                    test_loader, net, opt.patch_size, post_trans, dice_metric)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)

                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()