def segment_volume(folder_model: str, fname_images: list, gpu_id: int = 0, options: dict = None):
    """Segment an image.

    Segment an image (`fname_image`) using a pre-trained model (`folder_model`). If provided, a region of interest
    (`fname_roi`) is used to crop the image prior to segment it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1.
        gpu_id (int): Number representing gpu number if available. Currently does NOT support multiple GPU segmentation.
        options (dict): Contains postprocessing steps and prior filename (fname_prior) which is an image filename
            (e.g., .nii.gz) containing processing information (i.e., spinal cord segmentation, spinal location or
            MS lesion classification) e.g., spinal cord centerline, used to crop the image prior to segment it if
            provided. The segmentation is not performed on the slices that are empty in this image.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`
    """
    # Check if model folder exists and get filenames to be stored as string
    fname_model: str
    fname_model_metadata: str
    fname_model, fname_model_metadata = imed_models.get_model_filenames(folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(fname_model_metadata).get_config()

    # BUG FIX: the list previously contained ' fill_holes' (leading space), so an
    # options dict holding only 'fill_holes' never triggered set_postprocessing_options.
    postpro_list = ['binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small']
    if options is not None and any(pp in options for pp in postpro_list):
        set_postprocessing_options(options, context)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None

    fname_prior = options['fname_prior'] if (options is not None) and ('fname_prior' in options) else None

    # The prior is only used as an ROI when the model was trained with an ROI suffix.
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params']['suffix'] is not None:
            fname_roi = fname_prior

    # TRANSFORMATIONS
    metadata = process_transformations(context, fname_roi, fname_prior, metadata, slice_axis, fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(context["transformation"])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params["slice_filter_params"]:
        logger.warning("fname_roi has not been specified, then the entire volume is processed.")
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi, metadata if isinstance(metadata, list) else [metadata])]

    # 3D kernel if a Modified3DUNet is applied or the default model is not 2D.
    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
                not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(filename_pairs,
                                                           transform=tranform_lst,
                                                           length=context["Modified3DUNet"]["length_3D"],
                                                           stride=context["Modified3DUNet"]["stride_3D"])
        logger.info(f"Loaded {len(ds)} {loader_params['slice_axis']} volumes of shape "
                    f"{context['Modified3DUNet']['length_3D']}.")
    else:
        ds = imed_loader.MRI2DSegmentationDataset(filename_pairs,
                                                  slice_axis=slice_axis,
                                                  cache=True,
                                                  transform=tranform_lst,
                                                  slice_filter_fn=imed_loader_utils.SliceFilter(
                                                      **loader_params["slice_filter_params"]))
        ds.load_filenames()
        logger.info(f"Loaded {len(ds)} {loader_params['slice_axis']} slices.")

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        onehotencoder = get_onehotencoder(context, folder_model, options, ds)
        model_params.update({"name": 'FiLMedUnet',
                             "film_onehotencoder": onehotencoder,
                             "n_metadata": len([ll for l in onehotencoder.categories_ for ll in l])})

    # Data Loader
    data_loader = DataLoader(ds,
                             batch_size=context["training_parameters"]["batch_size"],
                             shuffle=False,
                             pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, weight_matrix, volume = False, None, None
    # BUG FIX: initialize the returned names so an empty data_loader (e.g. every
    # slice filtered out) returns empty lists instead of raising NameError.
    pred_list, target_list = [], []
    for i_batch, batch in enumerate(data_loader):
        preds = get_preds(context, fname_model, model_params, gpu_id, batch)

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        pred_list, target_list, last_sample_bool, weight_matrix, volume = reconstruct_3d_object(
            context, batch, undo_transforms, preds, preds_list, kernel_3D, slice_axis, slice_idx_list,
            data_loader, fname_images, i_batch, last_sample_bool, weight_matrix, volume)

    return pred_list, target_list
def segment_volume(folder_model, fname_images, gpu_number=0, options=None):
    """Segment an image.

    Segment an image (`fname_image`) using a pre-trained model (`folder_model`). If provided, a region of interest
    (`fname_roi`) is used to crop the image prior to segment it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1.
        gpu_number (int): Number representing gpu number if available.
        options (dict): Contains postprocessing steps and prior filename (fname_prior) which is an image filename
            (e.g., .nii.gz) containing processing information (i.e., spinal cord segmentation, spinal location or
            MS lesion classification) e.g., spinal cord centerline, used to crop the image prior to segment it if
            provided. The segmentation is not performed on the slices that are empty in this image.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`
    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cpu") if not cuda_available else torch.device("cuda:" + str(gpu_number))

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(fname_model_metadata).get_config()

    # BUG FIX: the list previously contained ' fill_holes' (leading space); since the
    # inline handling below correctly checks 'fill_holes', an options dict holding only
    # that key never entered this branch at all.
    postpro_list = ['binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small']
    if options is not None and any(pp in options for pp in postpro_list):
        postpro = {}
        if 'binarize_prediction' in options and options['binarize_prediction']:
            postpro['binarize_prediction'] = {"thr": options['binarize_prediction']}
        if 'keep_largest' in options and options['keep_largest'] is not None:
            if options['keep_largest']:
                postpro['keep_largest'] = {}
            # Remove key in context if value set to 0
            elif 'keep_largest' in context['postprocessing']:
                del context['postprocessing']['keep_largest']
        if 'fill_holes' in options and options['fill_holes'] is not None:
            if options['fill_holes']:
                postpro['fill_holes'] = {}
            # Remove key in context if value set to 0
            elif 'fill_holes' in context['postprocessing']:
                del context['postprocessing']['fill_holes']
        if 'remove_small' in options and options['remove_small'] and \
                ('mm' in options['remove_small'][-1] or 'vox' in options['remove_small'][-1]):
            unit = 'mm3' if 'mm3' in options['remove_small'][-1] else 'vox'
            thr = [int(t.replace(unit, "")) for t in options['remove_small']]
            postpro['remove_small'] = {"unit": unit, "thr": thr}
        context['postprocessing'].update(postpro)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options['fname_prior'] if (options is not None) and ('fname_prior' in options) else None
    # The prior is only used as an ROI when the model was trained with an ROI suffix.
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params']['suffix'] is not None:
            fname_roi = fname_prior

    # TRANSFORMATIONS
    # If ROI is not provided then force center cropping
    if fname_roi is None and 'ROICrop' in context["transformation"].keys():
        print("\n WARNING: fname_roi has not been specified, then a cropping around the center of the image is "
              "performed instead of a cropping around a Region of Interest.")
        context["transformation"] = dict((key, value) if key != 'ROICrop' else ('CenterCrop', value)
                                         for (key, value) in context["transformation"].items())

    if 'object_detection_params' in context and \
            context['object_detection_params']['object_detection_path'] is not None:
        imed_obj_detect.bounding_box_prior(fname_prior, metadata, slice_axis,
                                           context['object_detection_params']['safety_factor'])
        metadata = [metadata] * len(fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(context["transformation"])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params["slice_filter_params"]:
        print("\nWARNING: fname_roi has not been specified, then the entire volume is processed.")
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi, metadata if isinstance(metadata, list) else [metadata])]

    # 3D kernel if a Modified3DUNet is applied or the default model is not 2D.
    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
                not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(filename_pairs,
                                                           transform=tranform_lst,
                                                           length=context["Modified3DUNet"]["length_3D"],
                                                           stride=context["Modified3DUNet"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(filename_pairs,
                                                  slice_axis=slice_axis,
                                                  cache=True,
                                                  transform=tranform_lst,
                                                  slice_filter_fn=imed_loader_utils.SliceFilter(
                                                      **loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(len(ds), loader_params['slice_axis'],
                                                           context['Modified3DUNet']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds), loader_params['slice_axis']))

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        # Inject the FiLM metadata (and its training-time dictionary) into every sample
        # before normalization, then load the one-hot encoder fitted at training time.
        metadata_dict = joblib.load(os.path.join(folder_model, 'metadata_dict.joblib'))
        for idx in ds.indexes:
            for i in range(len(idx)):
                idx[i]['input_metadata'][0][context['FiLMedUnet']['metadata']] = options['metadata']
                idx[i]['input_metadata'][0]['metadata_dict'] = metadata_dict

        ds = imed_film.normalize_metadata(ds, None, context["debugging"], context['FiLMedUnet']['metadata'])
        onehotencoder = joblib.load(os.path.join(folder_model, 'one_hot_encoder.joblib'))

        model_params.update({"name": 'FiLMedUnet',
                             "film_onehotencoder": onehotencoder,
                             "n_metadata": len([ll for l in onehotencoder.categories_ for ll in l])})

    # Data Loader
    data_loader = DataLoader(ds,
                             batch_size=context["training_parameters"]["batch_size"],
                             shuffle=False,
                             pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # MODEL
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    # BUG FIX: initialize the returned names so an empty data_loader (e.g. every
    # slice filtered out) returns empty results instead of raising NameError.
    pred_list, target_list = [], []
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = imed_utils.cuda(batch['input'], cuda_available=cuda_available)

            if ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
                    ('HeMISUnet' in context and context['HeMISUnet']['applied']):
                metadata = imed_training.get_metadata(batch["input_metadata"], model_params)
                preds = model(img, metadata)
            else:
                preds = model(img) if fname_model.endswith('.pt') else onnx_inference(fname_model, img)

            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(undo_transforms.transforms, batch, i_slice)

            batch['gt_metadata'] = [[metadata[0]] * preds.shape[1] for metadata in batch['input_metadata']]
            if kernel_3D:
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(preds[i_slice],
                                                             batch["input_metadata"][i_slice],
                                                             data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))

                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1 and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(data_lst=preds_list,
                                       fname_ref=fname_images[0],
                                       fname_out=None,
                                       z_lst=slice_idx_list,
                                       slice_axis=slice_axis,
                                       kernel_dim='3d' if kernel_3D else '2d',
                                       debug=False,
                                       bin_thr=-1,
                                       postprocessing=context['postprocessing'])
                pred_list = split_classes(pred_nib)
                target_list = context['loader_parameters']['target_suffix']

    return pred_list, target_list
def segment_volume(folder_model, fname_image, fname_prior=None, gpu_number=0):
    """Segment an image.

    Segment an image (`fname_image`) using a pre-trained model (`folder_model`). If provided, a region of interest
    (`fname_roi`) is used to crop the image prior to segment it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_image (str): image filename (e.g. .nii.gz) to segment.
        fname_prior (str): Image filename (e.g. .nii.gz) containing processing information (i.e. spinal cord
            segmentation, spinal location or MS lesion classification) e.g. spinal cord centerline, used to crop
            the image prior to segment it if provided. The segmentation is not performed on the slices that are
            empty in this image.
        gpu_number (int): Number representing gpu number if available.

    Returns:
        nibabelObject: Object containing the soft segmentation.
    """
    # Define device: fall back to CPU when CUDA is unavailable.
    cuda_available = torch.cuda.is_available()
    device = torch.device("cpu") if not cuda_available else torch.device("cuda:" + str(gpu_number))

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(folder_model)

    # Load model training config (the json written at training time).
    with open(fname_model_metadata, "r") as fhandle:
        context = json.load(fhandle)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    # The prior is only used as an ROI when the model was trained with an ROI suffix.
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params']['suffix'] is not None:
            fname_roi = fname_prior

    # TRANSFORMATIONS
    # If ROI is not provided then force center cropping
    if fname_roi is None and 'ROICrop' in context["transformation"].keys():
        print("\nWARNING: fname_roi has not been specified, then a cropping around the center of the image is performed"
              " instead of a cropping around a Region of Interest.")
        # Rename the 'ROICrop' key to 'CenterCrop' while keeping its parameters and
        # the original ordering of the transformation dict.
        context["transformation"] = dict((key, value) if key != 'ROICrop' else ('CenterCrop', value)
                                         for (key, value) in context["transformation"].items())

    if 'object_detection_params' in context and \
            context['object_detection_params']['object_detection_path'] is not None:
        # Fills `metadata` in place with the bounding box derived from the prior.
        imed_obj_detect.bounding_box_prior(fname_prior, metadata, slice_axis)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(context["transformation"])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params["slice_filter_params"]:
        print("\nWARNING: fname_roi has not been specified, then the entire volume is processed.")
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    # Single subject with no ground truth: (images, gt, roi, metadata).
    filename_pairs = [([fname_image], None, fname_roi, [metadata])]

    # 3D sub-volume dataset when a 3D UNet is applied; 2D slice dataset otherwise.
    kernel_3D = bool('UNet3D' in context and context['UNet3D']['applied'])
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(filename_pairs,
                                                           transform=tranform_lst,
                                                           length=context["UNet3D"]["length_3D"],
                                                           stride=context["UNet3D"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(filename_pairs,
                                                  slice_axis=slice_axis,
                                                  cache=True,
                                                  transform=tranform_lst,
                                                  slice_filter_fn=SliceFilter(**loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(len(ds), loader_params['slice_axis'],
                                                           context['UNet3D']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds), loader_params['slice_axis']))

    # Data Loader
    data_loader = DataLoader(ds,
                             batch_size=context["training_parameters"]["batch_size"],
                             shuffle=False,
                             pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # MODEL
    # A '.pt' file is a torch-serialized model; otherwise the file is assumed to be
    # ONNX and is handled lazily by onnx_inference inside the loop.
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    # NOTE(review): if the data_loader yields no batch (e.g. all slices filtered out),
    # `pred_nib` is never bound and the final return raises NameError — confirm
    # whether an empty dataset is possible here.
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = cuda(batch['input'], cuda_available=cuda_available)
            preds = model(img) if fname_model.endswith('.pt') else onnx_inference(fname_model, img)
            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for modality in batch['input_metadata']:
            modality[0]['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(undo_transforms.transforms, batch, i_slice)

            if kernel_3D:
                batch['gt_metadata'] = batch['input_metadata']
                # Accumulate sub-volumes into the full volume; `last_sample_bool`
                # flips once the final sub-volume has been placed.
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(preds[i_slice],
                                                             batch["input_metadata"][i_slice],
                                                             data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))

                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1 and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(data_lst=preds_list,
                                       fname_ref=fname_image,
                                       fname_out=None,
                                       z_lst=slice_idx_list,
                                       slice_axis=slice_axis,
                                       kernel_dim='3d' if kernel_3D else '2d',
                                       debug=False,
                                       bin_thr=-1)

    return pred_nib