def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())

    device = torch.device('cuda:0')
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load('best_metric_model.pth'))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.
        metric_count = 0
        saver = NiftiSaver(output_dir='./output')
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                     to_onehot_y=False, add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print('evaluation metric:', metric)
    shutil.rmtree(tempdir)
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth"))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
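# The evaluation entry points above all take a scratch directory for the synthetic
# data. A minimal driver sketch (hedged: the MONAI example scripts use this exact
# pattern, but the __main__ guard itself is not shown above):
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)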
# assumes the framework is found here, change as necessary
sys.path.append("..")

import monai
from monai import config
from monai.data.nifti_reader import NiftiDataset
from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.mean_dice import MeanDice
from monai.visualize import img2tensorboard
from monai.data.synthetic import create_test_image_3d
from monai.handlers.utils import stopping_fn_from_metric

config.print_config()

# Create a temporary directory and 50 random image, mask pairs
tempdir = tempfile.mkdtemp()
for i in range(50):
    im, seg = create_test_image_3d(256, 256, 256)

    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
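# A hedged continuation sketch (not part of the original snippet): the file lists
# above would typically feed the imported NiftiDataset with the imported transforms.
# The Compose import path and DataLoader usage are assumptions for this early
# MONAI layout, not confirmed by the source.
from monai.transforms.compose import Compose  # assumed import path
from torch.utils.data import DataLoader

imtrans = Compose([Rescale(), AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])
segtrans = Compose([AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])
ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
loader = DataLoader(ds, batch_size=5, num_workers=2, pin_memory=torch.cuda.is_available())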
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI.
    It is composed of the following main blocks:
    * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
    * Load Data: Load training and validation data to PyTorch DataLoader
    * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
    * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
        during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
        on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
    * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
    * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
""" """ Read input and configuration parameters """ # print MONAI config information logging.basicConfig(stream=sys.stdout, level=logging.INFO) print_config() # print to log the parameter setups print(yaml.dump(config_info)) # extract network parameters, perform checks/set defaults if not present and print them to log if 'seg_labels' in config_info['training'].keys(): seg_labels = config_info['training']['seg_labels'] else: seg_labels = [1] nr_out_channels = len(seg_labels) print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels)) patch_size = config_info["training"]["inplane_size"] + [1] print("Considering patch size = {}".format(patch_size)) spacing = config_info["training"]["spacing"] print("Bringing all images to spacing = {}".format(spacing)) if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None: model_to_load = config_info['training']['model_to_load'] if not os.path.exists(model_to_load): raise FileNotFoundError("Cannot find model: {}".format(model_to_load)) else: print("Loading model from {}".format(model_to_load)) else: model_to_load = None # set up either GPU or CPU usage if torch.cuda.is_available(): print("\n#### GPU INFORMATION ###") print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name())) current_device = torch.device("cuda:0") else: current_device = torch.device("cpu") print("Using device: {}".format(current_device)) # set determinism if required if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None: seed = config_info['training']['manual_seed'] else: seed = None if seed is not None: print("Using determinism with seed = {}\n".format(seed)) set_determinism(seed=seed) """ Setup data output directory """ out_model_dir = os.path.join(config_info['output']['out_dir'], datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + config_info['output']['out_postfix']) print("Saving to directory {}\n".format(out_model_dir)) # create cache directory to store results for Persistent Dataset if 'cache_dir' in config_info['output'].keys(): out_cache_dir = config_info['output']['cache_dir'] else: out_cache_dir = os.path.join(out_model_dir, 'persistent_cache') persistent_cache: Path = Path(out_cache_dir) persistent_cache.mkdir(parents=True, exist_ok=True) """ Data preparation """ # Read the input files for training and validation print("*** Loading input data for training...") train_files = create_data_list_of_dictionaries(train_file_list) print("Number of inputs for training = {}".format(len(train_files))) val_files = create_data_list_of_dictionaries(valid_file_list) print("Number of inputs for validation = {}".format(len(val_files))) # Define MONAI processing transforms for the training data. This includes: # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3 # - CropForegroundd: Reduce the background from the MR image # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the # last direction (lowest resolution) to avoid introducing motion artefact resampling errors # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed # - NormalizeIntensityd: Apply whitening # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1] # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. 
    #       bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd,
    #       RandScaleIntensityd, RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2, keep_size=True,
                        mode=["bilinear", "nearest"], padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch sizes different from 1 at validation are currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))
{}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape)) """ Network preparation """ print("*** Preparing the network ...") # automatically extracts the strides and kernels based on nnU-Net empirical rules spacings = spacing[:2] sizes = patch_size[:2] strides, kernels = [], [] while True: spacing_ratio = [sp / min(spacings) for sp in spacings] stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] if all(s == 1 for s in stride): break sizes = [i / j for i, j in zip(sizes, stride)] spacings = [i * j for i, j in zip(spacings, stride)] kernels.append(kernel) strides.append(stride) strides.insert(0, len(spacings) * [1]) kernels.append(len(spacings) * [3]) # initialise the network net = DynUNet( spatial_dims=2, in_channels=1, out_channels=nr_out_channels, kernel_size=kernels, strides=strides, upsample_kernel_size=strides[1:], norm_name="instance", deep_supervision=True, deep_supr_num=2, res_block=False, ).to(current_device) print(net) # define the loss function loss_function = choose_loss_function(nr_out_channels, config_info) # define the optimiser and the learning rate scheduler opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95) scheduler = torch.optim.lr_scheduler.LambdaLR( opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9 ) """ MONAI evaluator """ print("*** Preparing the dynUNet evaluator engine...\n") # val_post_transforms = Compose( # [ # Activationsd(keys="pred", sigmoid=True), # ] # ) val_handlers = [ StatsHandler(output_transform=lambda x: None), TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"), output_transform=lambda x: None, global_epoch_transform=lambda x: trainer.state.iteration), CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True, file_prefix='best_valid'), ] if config_info['output']['val_image_to_tensorboad']: val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"), batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"], interval=2)) # Define customized evaluator class DynUNetEvaluator(SupervisedEvaluator): def _iteration(self, engine, batchdata): inputs, targets = self.prepare_batch(batchdata) inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device) flip_inputs_1 = torch.flip(inputs, dims=(2,)) flip_inputs_2 = torch.flip(inputs, dims=(3,)) flip_inputs_3 = torch.flip(inputs, dims=(2, 3)) def _compute_pred(): pred = self.inferer(inputs, self.network) # use random flipping as data augmentation at inference flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,)) flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,)) flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3)) return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4 # execute forward computation self.network.eval() with torch.no_grad(): if self.amp: with torch.cuda.amp.autocast(): predictions = _compute_pred() else: predictions = _compute_pred() return {"image": inputs, "label": targets, "pred": predictions} evaluator = DynUNetEvaluator( device=current_device, val_data_loader=val_loader, network=net, inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0), post_transform=None, key_val_metric={ "Mean_dice": MeanDice( include_background=False, 
    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True, save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
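# A hedged usage sketch for run_training: the file names follow the monaifbs
# examples cited in the docstring, and the YAML loading step is an assumption
# (it is not shown in the function above).
if __name__ == "__main__":
    import yaml

    with open("monaifbs/config/monai_dynUnet_training_config.yml") as f:
        config_info = yaml.safe_load(f)
    run_training(
        train_file_list="monaifbs/config/mock_train_file_list_for_dynUnet_training.txt",
        valid_file_list="monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt",
        config_info=config_info,
    )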
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            seg_probs = post_trans(seg_probs)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: output[0],
    )
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_saver = CheckpointLoader(load_path="./runs_array/net_checkpoint_100.pt", load_dict={"net": net})
    ckpt_saver.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    print(state)
from monai.transforms import (
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    KeepLargestConnectedComponent,
    LabelToContour,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    ToTensord,
)
from monai.utils import first, set_determinism
from numpy import math

print_config()
print("KIDNEYS")

"""## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.
This allows you to save results and reuse downloads.
If not specified a temporary directory will be used.
"""

"""## Download dataset

Downloads and extracts the dataset.
The dataset comes from http://medicaldecathlon.com/.
"""

md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
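# A hedged sketch of the data-directory and download steps described by the
# notebook text above. The Task09_Spleen URL is an assumption inferred from the
# md5 value, which matches the Task09_Spleen archive used in MONAI tutorials;
# download_and_extract comes from monai.apps.
import os
import tempfile
from monai.apps import download_and_extract

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory

resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"  # assumed URL
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)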
def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device('cuda:0')
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
        batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_saver = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
    ckpt_saver.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    shutil.rmtree(tempdir)
def main():
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    if not os.path.exists(logdir):
        os.mkdir(logdir)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir, data_list_file_path=json_path,
                                       is_segmentation=True, data_list_key="training")
    val_files = load_decathlon_datalist(base_dir=data_dir, data_list_file_path=json_path,
                                        is_segmentation=True, data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        vit_dict = torch.load(pretrained_path)
        vit_weights = vit_dict['state_dict']

        # Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
        # conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
        # while pretraining with ViTAutoEnc and are not
        # a part of ViT backbone (this is used in UNETR)
        vit_weights.pop('conv3d_transpose_1.bias')
        vit_weights.pop('conv3d_transpose_1.weight')
        vit_weights.pop('conv3d_transpose.bias')
        vit_weights.pop('conv3d_transpose.weight')

        model.vit.load_state_dict(vit_weights)
        print('Pretrained Weights Successfully Loaded !')
    elif use_pretrained == 0:
        print('No weights were loaded, all weights being used are randomly initialized!')

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val):
        model.eval()
        dice_vals = list()
        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
                val_labels_list = decollate_batch(val_labels)
                val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
                val_outputs_list = decollate_batch(val_outputs)
                val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, 10.0, dice))
            dice_metric.reset()
        mean_dice_val = np.mean(dice_vals)
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        model.train()
        epoch_loss = 0
        step = 0
        epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator):
            step += 1
            x, y = (batch["image"].cuda(), batch["label"].cuda())
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss))

            if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)",
                                          dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val)

                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(), os.path.join(logdir, "best_metric_model.pth"))
                    print("Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                          .format(dice_val_best, dice_val))
                else:
                    print("Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                          .format(dice_val_best, dice_val))
Dice: {}" .format(dice_val_best, dice_val)) plt.figure(1, (12, 6)) plt.subplot(1, 2, 1) plt.title("Iteration Average Loss") x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))] y = epoch_loss_values plt.xlabel("Iteration") plt.plot(x, y) plt.grid() plt.subplot(1, 2, 2) plt.title("Val Mean Dice") x = [eval_num * (i + 1) for i in range(len(metric_values))] y = metric_values plt.xlabel("Iteration") plt.plot(x, y) plt.grid() plt.savefig( os.path.join(logdir, 'btcv_finetune_quick_update.png')) plt.clf() plt.close(1) global_step += 1 return global_step, dice_val_best, global_step_best while global_step < max_iterations: global_step, dice_val_best, global_step_best = train( global_step, train_loader, dice_val_best, global_step_best) model.load_state_dict( torch.load(os.path.join(logdir, "best_metric_model.pth"))) print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")
def main(train_output):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images

    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    map_fn = lambda x: (x[0], int(x[1]))
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [map_fn(entry.strip().split(',')) for entry in fp.readlines()]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir, seg_dir, train_info_hackathon, dual_output=False)
    large_image_splitter(_train_data_hackathon, dirs["cache"])
    balance_training_data(_train_data_hackathon, seed=72)

    # PSUF data
    """psuf_dir = os.path.join(dirs["data"], 'psuf')
    with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
        train_info = [entry.strip().split(',') for entry in fp.readlines()]
    image_dir = os.path.join(psuf_dir, 'images')
    train_data_psuf = get_data_from_info(image_dir, None, train_info)"""

    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(
        _train_data_hackathon, test_size=0.2, shuffle=True, random_state=42)
    #train_data_hackathon, valid_data_hackathon = train_test_split(
    #    train_split, test_size=0.2, shuffle=True, random_state=43)

    # Setup transforms

    # Crop foreground
    crop_foreground = CropForegroundd(
        keys=["image"],
        source_key="image",
        margin=(5, 5, 0),
        #select_fn=lambda x: x != 0
    )
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"], spatial_size=(int(512 * 0.50), int(512 * 0.50), -1), mode="trilinear")

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    test_dataset = PersistentDataset(data=test_data_hackathon[:], transform=hackathon_train_transform,
                                     cache_dir=dirs["persistent"])
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=True, pin_memory=using_gpu, num_workers=1,
                             collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()
def run_inference(input_data, config_info):
    """
    Pipeline to run inference with MONAI dynUNet model. The pipeline reads the input filenames, applies the
    required preprocessing and creates the pytorch dataloader; it then performs evaluation on each input file
    using a trained dynUNet model (random flipping augmentation is applied at inference).
    It uses the dynUNet model implemented in the MONAI framework
    (https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/nets/dynunet.py)
    which is inspired by the nnU-Net framework (https://arxiv.org/abs/1809.10486)
    Inference is performed in 2D slice-by-slice, all slices are then recombined together into the 3D volume.

    Args:
        input_data: str or list of strings, filenames of images to be processed
        config_info: dict, contains the configuration parameters to reload the trained model
    """

    """
    Read input and configuration parameters
    """
    val_files = create_data_list_of_dictionaries(input_data)

    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print("*** MONAI config: ")
    print_config()

    # print to log the parameter setups
    print("*** Network inference config: ")
    print(yaml.dump(config_info))

    # inference params
    nr_out_channels = config_info['inference']['nr_out_channels']
    spacing = config_info["inference"]["spacing"]
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        raise FileNotFoundError('Trained model not found')
    patch_size = config_info["inference"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}".format(torch.cuda.current_device(),
                                                         torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    """
    Data Preparation
    """
    print("*** Preparing data ... ")
    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - resample to the training resolution in-plane (not along z)
    # - apply whitening
    # - convert to tensor
    val_transforms = Compose([
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        InPlaneSpacingd(
            keys=["image"],
            pixdim=spacing,
            mode="bilinear",
        ),
        NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
        ToTensord(keys=["image"]),
    ])

    # create a validation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=config_info['device']['num_workers'])

    def prepare_batch(batchdata):
        assert isinstance(batchdata, dict), "prepare_batch expects dictionary input data."
        return (
            (batchdata["image"], batchdata["label"])
            if "label" in batchdata
            else (batchdata["image"], None)
        )

    """
    Network preparation
    """
    print("*** Preparing network ... ")
") # automatically extracts the strides and kernels based on nnU-Net empirical rules spacings = spacing[:2] sizes = patch_size[:2] strides, kernels = [], [] while True: spacing_ratio = [sp / min(spacings) for sp in spacings] stride = [ 2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes) ] kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] if all(s == 1 for s in stride): break sizes = [i / j for i, j in zip(sizes, stride)] spacings = [i * j for i, j in zip(spacings, stride)] kernels.append(kernel) strides.append(stride) strides.insert(0, len(spacings) * [1]) kernels.append(len(spacings) * [3]) net = DynUNet(spatial_dims=2, in_channels=1, out_channels=nr_out_channels, kernel_size=kernels, strides=strides, upsample_kernel_size=strides[1:], norm_name="instance", deep_supervision=True, deep_supr_num=2, res_block=False).to(current_device) """ Set ignite evaluator to perform inference """ print("*** Preparing evaluator ... ") if nr_out_channels == 1: do_sigmoid = True do_softmax = False elif nr_out_channels > 1: do_sigmoid = False do_softmax = True else: raise Exception("incompatible number of output channels") print("Using sigmoid={} and softmax={} as final activation".format( do_sigmoid, do_softmax)) val_post_transforms = Compose([ Activationsd(keys="pred", sigmoid=do_sigmoid, softmax=do_softmax), AsDiscreted(keys="pred", argmax=True, threshold_values=True, logit_thresh=prob_thr), KeepLargestConnectedComponentd(keys="pred", applied_labels=1) ]) val_handlers = [ StatsHandler(output_transform=lambda x: None), CheckpointLoader(load_path=model_to_load, load_dict={"net": net}, map_location=torch.device('cpu')), SegmentationSaver( output_dir=config_info['output']['out_dir'], output_ext='.nii.gz', output_postfix=config_info['output']['out_postfix'], batch_transform=lambda batch: batch["image_meta_dict"], output_transform=lambda output: output["pred"], ), ] # Define customized evaluator class DynUNetEvaluator(SupervisedEvaluator): def _iteration(self, engine, batchdata): inputs, targets = self.prepare_batch(batchdata) inputs = inputs.to(engine.state.device) if targets is not None: targets = targets.to(engine.state.device) flip_inputs_1 = torch.flip(inputs, dims=(2, )) flip_inputs_2 = torch.flip(inputs, dims=(3, )) flip_inputs_3 = torch.flip(inputs, dims=(2, 3)) def _compute_pred(): pred = self.inferer(inputs, self.network) # use random flipping as data augmentation at inference flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2, )) flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3, )) flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3)) return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4 # execute forward computation self.network.eval() with torch.no_grad(): if self.amp: with torch.cuda.amp.autocast(): predictions = _compute_pred() else: predictions = _compute_pred() return {"image": inputs, "label": targets, "pred": predictions} evaluator = DynUNetEvaluator( device=current_device, val_data_loader=val_loader, network=net, prepare_batch=prepare_batch, inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0), post_transform=val_post_transforms, val_handlers=val_handlers, amp=False, ) """ Run inference """ print("*** Running evaluator ... ") evaluator.run() print("Done!") return
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        Invertd(
            keys="pred",  # invert the `pred` data field, also support multiple fields
            transform=pre_transforms,
            orig_keys="img",  # get the previously applied pre_transforms information on the `img` data field,
                              # then invert `pred` based on this information. we can use same info
                              # for multiple fields, also support different orig_keys for different fields
            meta_keys="pred_meta_dict",  # key field to save inverted meta data, every item maps to `keys`
            orig_meta_keys="img_meta_dict",  # get the meta data from `img_meta_dict` field when inverting,
                                             # for example, may need the `affine` to invert `Spacingd` transform,
                                             # multiple fields can use the same meta data to invert
            meta_key_postfix="meta_dict",  # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
                                           # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
                                           # otherwise, no need this arg during inverting
            nearest_interp=False,  # don't change the interpolation mode to "nearest" when inverting transforms
                                   # to ensure a smooth output, then execute `AsDiscreted` transform
            to_tensor=True,  # convert to PyTorch Tensor after inverting
        ),
        AsDiscreted(keys="pred", threshold=0.5),
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out",
                   output_postfix="seg", resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images, roi_size=(96, 96, 96),
                                                 sw_batch_size=4, predictor=net)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]
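# A hedged note on the loop above: after `post_transforms` runs, `d` is a list of
# per-sample dictionaries whose "pred" entries hold the inverted (original-space)
# segmentations. A minimal sketch of pulling them out for downstream use
# (illustrative only; variable names are assumptions):
inverted_preds = [item["pred"] for item in d]
print("number of inverted predictions in last batch:", len(inverted_preds))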
def main(tempdir):
    print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, _ = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose([
        LoadImaged(keys="img"),
        EnsureChannelFirstd(keys="img"),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
        ScaleIntensityd(keys="img"),
        ToTensord(keys="img"),
    ])
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
    # define post transforms
    post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        Invertd(keys="pred", transform=pre_transforms, loader=dataloader,
                orig_keys="img", nearest_interp=True),
        SaveImaged(keys="pred_inverted", output_dir="./output", output_postfix="seg", resample=False),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
    net.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images, roi_size=(96, 96, 96),
                                                 sw_batch_size=4, predictor=net)
            # execute post transforms to invert spatial transforms and save to NIfTI files
            post_transforms(d)
def main():
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images

    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    map_fn = lambda x: (x[0], int(x[1]))
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [map_fn(entry.strip().split(',')) for entry in fp.readlines()]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir, seg_dir, train_info_hackathon, dual_output=False)
    _train_data_hackathon = large_image_splitter(_train_data_hackathon, dirs["cache"])
    copy_list = transform_and_copy(_train_data_hackathon, dirs['cache'])
    balance_training_data2(_train_data_hackathon, copy_list, seed=72)

    # PSUF data
    """psuf_dir = os.path.join(dirs["data"], 'psuf')
    with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
        train_info = [entry.strip().split(',') for entry in fp.readlines()]
    image_dir = os.path.join(psuf_dir, 'images')
    train_data_psuf = get_data_from_info(image_dir, None, train_info)"""

    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(
        _train_data_hackathon, test_size=0.2, shuffle=True, random_state=42)
    train_data_hackathon, valid_data_hackathon = train_test_split(
        train_split, test_size=0.2, shuffle=True, random_state=43)
    #balance_training_data(train_data_hackathon, seed=72)
    #balance_training_data(valid_data_hackathon, seed=73)
    #balance_training_data(test_data_hackathon, seed=74)

    # Setup transforms

    # Crop foreground
    crop_foreground = CropForegroundd(keys=["image"], source_key="image", margin=(5, 5, 0),
                                      select_fn=lambda x: x != 0)
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    # Random axis flip
    rand_x_flip = RandFlipd(keys=["image"], spatial_axis=0, prob=0.50)
    rand_y_flip = RandFlipd(keys=["image"], spatial_axis=1, prob=0.50)
    rand_z_flip = RandFlipd(keys=["image"], spatial_axis=2, prob=0.50)
    # Rand affine transform
    rand_affine = RandAffined(keys=["image"],
                              prob=0.5,
                              rotate_range=(0, 0, np.pi / 12),
                              shear_range=(0.07, 0.07, 0.0),
                              translate_range=(0, 0, 0),
                              scale_range=(0.07, 0.07, 0.0),
                              padding_mode="zeros")
    # Pad image to have height at least 30
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"], spatial_size=(int(512 * 0.50), int(512 * 0.50), -1), mode="trilinear")
    # Apply Gaussian noise
    rand_gaussian_noise = RandGaussianNoised(keys=["image"], prob=0.25, mean=0.0, std=0.1)

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        rand_x_flip,
        rand_y_flip,
        rand_z_flip,
        rand_affine,
        rand_gaussian_noise,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_valid_transform = Compose([
        common_transform,
        #rand_x_flip,
        #rand_y_flip,
        #rand_z_flip,
        #rand_affine,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    train_dataset = PersistentDataset(data=train_data_hackathon[:], transform=hackathon_train_transform,
                                      cache_dir=dirs["persistent"])
    valid_dataset = PersistentDataset(data=valid_data_hackathon[:], transform=hackathon_valid_transform,
                                      cache_dir=dirs["persistent"])
    test_dataset = PersistentDataset(data=test_data_hackathon[:], transform=hackathon_test_transform,
                                     cache_dir=dirs["persistent"])
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        #shuffle=True,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(train_data_hackathon,
                                         callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(valid_data_hackathon,
                                         callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    test_loader = DataLoader(
        test_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)

    # pos_weight for class imbalance
    _, n, p = calculate_class_imbalance(train_data_hackathon)
    pos_weight = torch.Tensor([n, p]).to(device)
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight)
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95, last_epoch=-1)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        #Activationsd(keys="pred", softmax=True),
    ])
    validator = Validator(device=device,
                          val_data_loader=valid_loader,
                          network=network,
                          post_transform=valid_post_transforms,
                          amp=using_gpu,
                          non_blocking=using_gpu)
    trainer = Trainer(device=device,
                      out_dir=dirs["out"],
                      out_name="DenseNet121",
                      max_epochs=120,
                      validation_epoch=1,
                      validation_interval=1,
                      train_data_loader=train_loader,
                      network=network,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      lr_scheduler=None,
                      validator=validator,
                      amp=using_gpu,
                      non_blocking=using_gpu)

    """x_max, y_max, z_max, size_max = 0, 0, 0, 0
    for data in valid_loader:
        image = data["image"]
        label = data["label"]
        print()
        print(len(data['image_transforms']))
        #print(data['image_transforms'])
        print(label)
        shape = image.shape
        x_max = max(x_max, shape[-3])
        y_max = max(y_max, shape[-2])
        z_max = max(z_max, shape[-1])
        size = int(image.nelement()*image.element_size()/1024/1024)
        size_max = max(size_max, size)
        print("shape:", shape, "size:", str(size)+"MB")
        #multi_slice_viewer(image[0, 0, :, :, :], str(label))
    print(x_max, y_max, z_max, str(size_max)+"MB")
    exit()"""

    # Run trainer
    train_output = trainer.run()

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()
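# A hedged sketch of the helper assumed by the pos_weight computation above.
# This is a hypothetical implementation, not the project's actual code: it
# returns (total, negative_count, positive_count) from each item's '_label'
# field, to match the call site `_, n, p = calculate_class_imbalance(...)`.
def calculate_class_imbalance(data):
    labels = [item["_label"] for item in data]
    p = sum(1 for label in labels if label == 1)  # positive samples
    n = len(labels) - p                           # negative samples
    return len(labels), n, p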