def __init__(
    self,
    output_dir: str = "./",
    output_postfix: str = "seg",
    output_ext: str = ".nii.gz",
    resample: bool = True,
    mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
    padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
    scale=None,
    dtype: Optional[np.dtype] = None,
    batch_transform: Callable = lambda x: x,
    output_transform: Callable = lambda x: x,
    name: Optional[str] = None,
):
    """
    Args:
        output_dir: output image directory.
        output_postfix: a string appended to all output file names.
        output_ext: output file extension name. Must be ``".nii.gz"``, ``".nii"`` or ``".png"``.
        resample: whether to resample before saving the data array.
        mode: This option is used when ``resample = True``. Defaults to ``"nearest"``.

            - NIfTI files {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                The interpolation mode.
                See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate

        padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.

            - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
            - PNG files
                This option is ignored.

        scale (255, 65535): postprocess data by clipping to [0, 1] and scaling
            [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
            It's used for PNG format only.
        dtype (np.dtype, optional): convert the image data to save to this data type.
            If None, keep the original type of data. It's used for Nifti format only.
        batch_transform: a callable that is used to transform the
            ignite.engine.batch into expected format to extract the meta_data dictionary.
        output_transform: a callable that is used to transform the
            ignite.engine.output into the form expected image data.
            The first dimension of this transform's output will be treated as the
            batch dimension. Each item in the batch will be saved individually.
        name: identifier of logging.logger to use, defaulting to `engine.logger`.

    Raises:
        ValueError: if ``output_ext`` is not ``".nii.gz"``, ``".nii"`` or ``".png"``.
    """
    self.saver: Union[NiftiSaver, PNGSaver]
    if output_ext in (".nii.gz", ".nii"):
        self.saver = NiftiSaver(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            resample=resample,
            mode=mode,
            padding_mode=padding_mode,
            dtype=dtype,
        )
    elif output_ext == ".png":
        # padding_mode and dtype are not forwarded: the PNG saver does not use them.
        self.saver = PNGSaver(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            resample=resample,
            mode=mode,
            scale=scale,
        )
    else:
        # Without this branch ``self.saver`` would stay unset and the handler
        # would only fail later, far away from the misconfiguration.
        raise ValueError(
            f"unsupported output_ext {output_ext!r}; expected '.nii.gz', '.nii' or '.png'."
        )
    self.batch_transform = batch_transform
    self.output_transform = output_transform

    self.logger = None if name is None else logging.getLogger(name)
    self._name = name
def __init__(
    self,
    output_dir: str = "./",
    output_postfix: str = "seg",
    output_ext: str = ".nii.gz",
    resample: bool = True,
    interp_order: str = "nearest",
    mode: str = "border",
    scale=None,
    dtype: Optional[np.dtype] = None,
    batch_transform: Callable = lambda x: x,
    output_transform: Callable = lambda x: x,
    name: Optional[str] = None,
):
    """
    Args:
        output_dir: output image directory.
        output_postfix: a string appended to all output file names.
        output_ext: output file extension name. Must be ".nii.gz", ".nii" or ".png".
        resample: whether to resample before saving the data array.
        interp_order: The interpolation mode. Defaults to "nearest".
            This option is used when `resample = True`.
            When saving NIfTI files, the available options are "nearest", "bilinear".
            See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample.
            When saving PNG files, the available options are "nearest", "bilinear", "bicubic", "area".
            See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate.
        mode: The mode parameter determines how the input array is extended beyond its boundaries.
            This option is used when `resample = True`.
            When saving NIfTI files, the options are "zeros", "border", "reflection". Default is "border".
            When saving PNG files, this option is ignored.
        scale (255, 65535): postprocess data by clipping to [0, 1] and scaling
            [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.
            It's used for PNG format only.
        dtype (np.dtype, optional): convert the image data to save to this data type.
            If None, keep the original type of data. It's used for Nifti format only.
        batch_transform: a callable that is used to transform the
            ignite.engine.batch into expected format to extract the meta_data dictionary.
        output_transform: a callable that is used to transform the
            ignite.engine.output into the form expected image data.
            The first dimension of this transform's output will be treated as the
            batch dimension. Each item in the batch will be saved individually.
        name: identifier of logging.logger to use, defaulting to `engine.logger`.

    Raises:
        ValueError: if ``output_ext`` is not ".nii.gz", ".nii" or ".png".
    """
    self.saver: Union[NiftiSaver, PNGSaver]
    if output_ext in (".nii.gz", ".nii"):
        self.saver = NiftiSaver(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            resample=resample,
            interp_order=interp_order,
            mode=mode,
            dtype=dtype,
        )
    elif output_ext == ".png":
        # mode and dtype are not forwarded: the PNG saver does not use them.
        self.saver = PNGSaver(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            resample=resample,
            interp_order=interp_order,
            scale=scale,
        )
    else:
        # Without this branch ``self.saver`` would stay unset and the handler
        # would only fail later, far away from the misconfiguration.
        raise ValueError(
            f"unsupported output_ext {output_ext!r}; expected '.nii.gz', '.nii' or '.png'."
        )
    self.batch_transform = batch_transform
    self.output_transform = output_transform

    self.logger = None if name is None else logging.getLogger(name)
    self._name = name
def main():
    """Evaluate a trained 3D UNet on synthetic NIfTI data with sliding-window inference.

    Five synthetic image/segmentation pairs are written to a temporary
    directory and segmented with the weights in ``best_metric_model.pth``.
    The thresholded predictions are saved to ``./output`` via ``NiftiSaver``,
    and the mean Dice is printed. Runs on ``cuda:0`` when CUDA is available,
    otherwise on the CPU. The temporary directory is removed even if
    evaluation fails.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
        # sliding window inference for one image at every iteration
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())

        # Fall back to the CPU so the script also runs without CUDA; map_location
        # stops torch.load from trying to restore GPU tensors on such machines.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        model.load_state_dict(torch.load("best_metric_model.pth", map_location=device))
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                # items 0/1 are image and label; item 2 is handed to the saver as meta data
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                value = compute_meandice(
                    y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
                # binarise the logits before writing them to disk
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(val_outputs, val_data[2])
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
    finally:
        shutil.rmtree(tempdir)
def main():
    """Run inference with a 2D UNet on the subjects listed in a YAML config file.

    The config (passed via ``--config``) is read for its ``device``,
    ``inference``, ``data`` and ``output`` sections. Each image is resized
    in-plane to 96x96 and segmented with sliding-window inference. The result
    is thresholded at ``inference.probability_threshold``, resized back to the
    original size, and post-processed with ``padded_binary_closing`` and
    ``get_largest_component``. It is then written with ``NiftiSaver`` to
    ``output.out_dir``.

    Raises:
        ValueError: if ``inference.batch_size_inference`` is not 1, or no
            inference subjects are found.
        FileNotFoundError: if ``inference.model_to_load`` does not exist.
    """
    # ---- read input and configuration parameters ----
    parser = argparse.ArgumentParser(
        description='Run inference with basic UNet with MONAI.')
    parser.add_argument('--config', dest='config', metavar='config', type=str,
                        required=True, help='config file')
    args = parser.parse_args()
    with open(args.config) as f:
        # NOTE(review): FullLoader still builds some Python-specific tags; prefer
        # yaml.safe_load if config files may come from untrusted sources.
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # inference params
    batch_size_inference = config_info['inference']['batch_size_inference']
    # temporary check as sliding window inference does not accept higher batch size;
    # raised explicitly (not asserted) so it still runs under ``python -O``.
    if batch_size_inference != 1:
        raise ValueError(
            'batch_size_inference must be 1 for sliding window inference, '
            f'got {batch_size_inference}')
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        # FileNotFoundError subclasses IOError/OSError, so existing handlers still match.
        raise FileNotFoundError(f'Trained model not found: {model_to_load}')
    # data params
    data_root = config_info['data']['data_root']
    inference_list = config_info['data']['inference_list']
    # output saving
    out_dir = config_info['output']['out_dir']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)

    # ---- data preparation ----
    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=inference_list,
                                 img_postfix='_Image',
                                 is_inference=True)
    if not val_files:
        raise ValueError('No inference subjects found for the configured data_root/inference_list')

    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - NOTE: resizing needs to be applied afterwards, otherwise it cannot be remapped back to original size
    val_transforms = Compose([
        LoadNiftid(keys=['img']),
        AddChanneld(keys=['img']),
        NormalizeIntensityd(keys=['img']),
        ToTensord(keys=['img'])
    ])
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_inference,
                            num_workers=num_workers)

    # ---- network preparation ----
    device = torch.cuda.current_device()
    # Create the UNet and load the trained weights (inference only).
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load(model_to_load))
    net.eval()

    # ---- run inference ----
    # structuring element for the post-processing closing; loop-invariant
    strt = ndimage.generate_binary_structure(3, 2)
    with torch.no_grad():
        saver = NiftiSaver(output_dir=out_dir)
        for val_data in val_loader:
            val_images = val_data['img'].to(device)
            orig_size = list(val_images.shape)
            # resize only the first two spatial dims to 96; keep the remaining dim(s)
            resized_size = orig_size[:2] + [96, 96] + orig_size[4:]
            val_images_resize = torch.nn.functional.interpolate(
                val_images, size=resized_size[2:], mode='trilinear')
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 1)
            val_outputs = sliding_window_inference(val_images_resize, roi_size,
                                                   batch_size_inference, net)
            val_outputs = (val_outputs.sigmoid() >= prob_thr).float()
            # 'nearest' keeps the binary mask binary when resizing back
            val_outputs_resized = torch.nn.functional.interpolate(
                val_outputs, size=orig_size[2:], mode='nearest')

            # add post-processing
            val_outputs_resized = val_outputs_resized.detach().cpu().numpy()
            post = padded_binary_closing(np.squeeze(val_outputs_resized), strt)
            post = get_largest_component(post)
            val_outputs_resized = val_outputs_resized * post

            saver.save_batch(
                val_outputs_resized, {
                    'filename_or_obj': val_data['img.filename_or_obj'],
                    'affine': val_data['img.affine']
                })
# Evaluate a trained 3D UNet with sliding-window inference and accumulate the
# Dice values returned by compute_meandice over ``val_loader``.
# NOTE(review): this is a fragment; ``val_loader`` comes from code not visible
# here and is assumed to yield dicts with 'img' and 'seg' tensors.
# Hard-coded to the first CUDA device; there is no CPU fallback.
device = torch.device('cuda:0')
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

model.load_state_dict(torch.load('best_metric_model.pth'))
model.eval()
with torch.no_grad():
    metric_sum = 0.
    metric_count = 0
    # NOTE(review): ``saver`` is not used in the visible code; presumably used
    # further on in the original routine — confirm.
    saver = NiftiSaver(output_dir='./output')
    for val_data in val_loader:
        val_images, val_labels = val_data['img'].to(
            device), val_data['seg'].to(device)
        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_images, roi_size,
                                               sw_batch_size, model)
        # NOTE(review): other calls to compute_meandice in this file pass
        # ``sigmoid=True``; confirm ``add_sigmoid`` is accepted by the installed
        # MONAI version.
        value = compute_meandice(y_pred=val_outputs,
                                 y=val_labels,
                                 include_background=True,
                                 to_onehot_y=False,
                                 add_sigmoid=True)
        # counts entries along the first dim of ``value`` (presumably the batch)
        metric_count += len(value)
        metric_sum += value.sum().item()
def run_inference(self, model, data_loader):
    """Run sliding-window inference on ``data_loader`` and report Dice scores.

    For every item:
      - its Dice score against ``data["label"]`` is computed with
        ``self.compute_dice_score`` and logged;
      - if ``self.export_inferred_segmentations`` is set, the argmax
        segmentation is written with ``NiftiSaver`` under
        ``self.results_folder_path``;
      - a three-panel figure (image, label, output at the slice chosen by
        ``self.get_center_of_mass_slice``) is saved to ``self.figures_path``.

    Afterwards a histogram of all Dice scores is saved, and the scores with
    their mean/std are logged.

    Args:
        model: the network to evaluate. When ``self.model == "UNet2d5_spvPA"``
            only element 0 of its output is used as the segmentation.
        data_loader: sized iterable of dicts with at least ``"image"``,
            ``"label"`` and ``"label_meta_dict"`` entries.
    """
    logger = self.logger
    logger.info('Running inference...')

    model.eval()  # activate evaluation mode of model
    dice_scores = np.zeros(len(data_loader))

    if self.model == "UNet2d5_spvPA":
        def model_segmentation(*args, **kwargs):
            # this architecture returns several outputs; index 0 is used as the segmentation
            return model(*args, **kwargs)[0]
    else:
        model_segmentation = model

    with torch.no_grad():  # turns off PyTorch's auto grad for better performance
        for i, data in enumerate(data_loader):
            logger.info("starting image {}".format(i))

            outputs = sliding_window_inference(
                inputs=data["image"].to(self.device),
                roi_size=self.sliding_window_inferer_roi_size,
                sw_batch_size=1,
                predictor=model_segmentation,
                mode="gaussian",
            )

            dice_score = self.compute_dice_score(outputs, data["label"].to(self.device))
            dice_scores[i] = dice_score.item()

            logger.info(f"dice_score = {dice_score.item()}")

            # export to nifti
            if self.export_inferred_segmentations:
                logger.info("export to nifti...")

                nifti_data_matrix = np.squeeze(torch.argmax(outputs, dim=1, keepdim=True))[None, :]
                meta = data['label_meta_dict']
                # unwrap the leading dimension added by collation (batch size 1 assumed)
                meta['filename_or_obj'] = meta['filename_or_obj'][0]
                meta['affine'] = np.squeeze(meta['affine'])
                meta['original_affine'] = np.squeeze(meta['original_affine'])

                folder_name = os.path.basename(os.path.dirname(meta['filename_or_obj']))

                saver = NiftiSaver(
                    output_dir=os.path.join(self.results_folder_path,
                                            'inferred_segmentations_nifti',
                                            folder_name),
                    output_postfix='')
                saver.save(nifti_data_matrix, meta_data=meta)

            # plot centre of mass slice of label
            label = torch.squeeze(data["label"][0, 0, :, :, :])
            # choose slice of selected validation set image volume for the figure
            slice_idx = self.get_center_of_mass_slice(label)
            plt.figure("check", (18, 6))
            plt.clf()
            plt.subplot(1, 3, 1)
            plt.title("image " + str(i) + ", slice = " + str(slice_idx))
            plt.imshow(data["image"][0, 0, :, :, slice_idx], cmap="gray", interpolation="none")
            plt.subplot(1, 3, 2)
            plt.title("label " + str(i))
            plt.imshow(data["label"][0, 0, :, :, slice_idx], interpolation="none")
            plt.subplot(1, 3, 3)
            plt.title("output " + str(i) + f", dice = {dice_score.item():.4}")
            plt.imshow(torch.argmax(outputs, dim=1).detach().cpu()[0, :, :, slice_idx],
                       interpolation="none")
            plt.savefig(os.path.join(self.figures_path, "best_model_output_val" + str(i) + ".png"))

    # The named figure is reused across calls; clear it so histograms from
    # earlier runs do not accumulate in the saved image.
    plt.figure("dice score histogram")
    plt.clf()
    plt.hist(dice_scores, bins=np.arange(0, 1.01, 0.01))
    plt.savefig(os.path.join(self.figures_path, "best_model_output_dice_score_histogram.png"))

    logger.info(f"all_dice_scores = {dice_scores}")
    logger.info(f"mean_dice_score = {dice_scores.mean()} +- {dice_scores.std()}")
def main():
    """Evaluate a trained 3D UNet on synthetic data with sliding-window inference.

    Five synthetic image/segmentation pairs are written to a temporary
    directory and segmented with the weights in ``best_metric_model.pth``.
    When more than one device is found, inference runs through DataParallel.
    Thresholded predictions are saved to ``./output`` and the mean Dice is
    printed. The temporary directory is removed even if evaluation fails.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose(
            [
                LoadNiftid(keys=["img", "seg"]),
                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
                ScaleIntensityd(keys=["img", "seg"]),
                ToTensord(keys=["img", "seg"]),
            ]
        )
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        # sliding window inference need to input 1 image in every iteration
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
        dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

        # try to use all the available GPUs
        devices = get_devices_spec(None)
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(devices[0])

        model.load_state_dict(torch.load("best_metric_model.pth"))

        # if we have multiple GPUs, set data parallel to execute sliding window inference
        if len(devices) > 1:
            model = torch.nn.DataParallel(model, device_ids=devices)

        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                # With reduction="mean" the metric is presumably a single batch-averaged
                # value (possibly a 0-d tensor, on which len() raises TypeError), so
                # weight it by the batch size rather than len(value).
                batch_size = val_images.shape[0]
                metric_count += batch_size
                metric_sum += value.item() * batch_size
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(val_outputs, val_data["img_meta_dict"])
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
    finally:
        shutil.rmtree(tempdir)
device) #2nd argument: training_batch_size or 16 with torch.no_grad(): logits = forward(model, inputs) labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True) batch_mri = inputs batch_label = labels #slices = torch.cat((batch_mri, batch_label)) #slices = torch.cat((batch_mri, batch_label),dim=1) #inf_path = 'inference.png' #save_image(slices, inf_path, nrow=training_batch_size//2, normalize=True, scale_each=True, padding=0) #display.Image(inf_path) #saver = NiftiSaver(output_dir="./niftinferece",output_postfix = str(i)) #saver.save_batch(slices) saver1 = NiftiSaver(output_dir="./inputsnifti", output_postfix=str(i)) saver2 = NiftiSaver(output_dir="./labelsnifti", output_postfix=str(i)) saver1.save_batch(inputs) saver2.save_batch(labels) #Dice score for inference slide dice_score.append( get_dice_score(F.softmax(logits, dim=CHANNELS_DIMENSION), targets)) #Dice loss for inference slide dice_losses.append( get_dice_loss(F.softmax(logits, dim=CHANNELS_DIMENSION), targets)) ## experimental to output all dice scores -- WORKS! dataset_iter = iter(validation_loader2) for i in range(5): try:
def run_inference_test(root_dir, device="cuda:0"):
    """Segment the ``im*.nii.gz`` volumes in ``root_dir`` and return the mean Dice.

    Images and labels are paired by sorted file name. Predictions are
    post-processed (sigmoid + threshold) per item and written to
    ``root_dir/output``.

    Args:
        root_dir: directory containing ``im*.nii.gz``/``seg*.nii.gz`` pairs and
            the checkpoint ``best_metric_model.pth``.
        device: torch device used for the model, the data and for mapping the
            checkpoint tensors when loading.

    Returns:
        The aggregated Dice score as a Python float.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        EnsureChannelFirstd(keys=["img", "seg"]),
        # resampling with align_corners=True or dtype=float64 will generate
        # slight different results between PyTorch 1.5 an 1.6
        Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([
        ToTensor(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True)
    ])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location honours ``device``: without it a GPU-saved checkpoint cannot be
    # loaded on a CPU-only machine even when device="cpu" is requested.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    with eval_mode(model):
        # resampling with align_corners=True or dtype=float64 will generate
        # slight different results between PyTorch 1.5 an 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            # decollate prediction into a list and execute post processing for every item
            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]
            # compute metrics
            dice_metric(y_pred=val_outputs, y=val_labels)
            saver.save_batch(val_outputs, val_data["img_meta_dict"])

    return dice_metric.aggregate().item()