def save_feature_map(batch, layer_name, path_output, model, test_input, slice_axis):
    """Save model feature maps.

    For every sample ``i`` of the batch, the output of layer ``layer_name`` is captured with a
    ``HookBasedFeatureExtractor`` and upsampled (trilinear) to the spatial size of the sample.
    Both the input image and the upsampled map are reoriented to the space of the original input
    file and saved in ``path_output/layer_name`` as ``<basename>`` and ``<stem>_att.nii.gz``.

    Args:
        batch (dict): Batch from the loader. ``batch['input'].size(0)`` gives the number of
            samples; ``batch['input_metadata'][0][i]['input_filenames']`` is the path of the
            original input file of sample ``i``.
        layer_name (str): Name of the layer whose output is extracted. Also used as the name of
            the output sub-folder.
        path_output (str): Output folder. Created (with parents) if missing.
        model (nn.Module): Network.
        test_input (Tensor): Network input; ``test_input[i]`` is sample ``i`` (a batch dimension
            is added before the forward pass). Trilinear interpolation implies volumetric data,
            presumably ``(C, H, W, D)`` per sample -- TODO confirm.
        slice_axis (int): Indicates the axis used for the 2D slice extraction:
            Sagittal: 0, Coronal: 1, Axial: 2.
    """
    path_layer = Path(path_output, layer_name)
    # parents/exist_ok: also create a missing `path_output`, and avoid the exists()/mkdir() race.
    path_layer.mkdir(parents=True, exist_ok=True)

    # Save for subject in batch
    for i in range(batch['input'].size(0)):
        sample_input = test_input[i][None, ]  # add a batch dimension
        _, out_fmap = HookBasedFeatureExtractor(model, layer_name, False).forward(Variable(sample_input))

        # Input image, and the layer output upsampled to the input's spatial size
        orig_input_img = sample_input.cpu().numpy()
        upsampled_attention = F.interpolate(out_fmap[1],
                                            size=sample_input.size()[2:],
                                            mode='trilinear',
                                            align_corners=True).data.cpu().numpy()

        path = batch["input_metadata"][0][i]["input_filenames"]
        # Path(...).name is separator-agnostic, unlike split('/')
        basename = Path(path).name

        nib_ref = nib.load(path)
        nib_ref_can = nib.as_closest_canonical(nib_ref)

        # The input image keeps its original file name; the attention map gets an "_att" suffix.
        outputs = (
            (orig_input_img, basename),
            (upsampled_attention, basename.split(".")[0] + "_att.nii.gz"),
        )
        for volume, fname in outputs:
            oriented = imed_loader_utils.reorient_image(volume[0, 0, :, :, :], slice_axis, nib_ref, nib_ref_can)
            nib_pred = nib.Nifti1Image(dataobj=oriented,
                                       affine=nib_ref.header.get_best_affine(),
                                       header=nib_ref.header.copy())
            nib.save(nib_pred, path_layer / fname)
def get_midslice_average(path_im, ind, slice_axis=0, half_width=3):
    """Extract an average 2D slice out of a 3D volume.

    The image is reoriented to its closest canonical orientation. The slices
    ``ind - half_width`` to ``ind + half_width`` (inclusive) along ``slice_axis`` are then
    averaged. With the default ``half_width=3`` this is 7 slices. Near the volume borders the
    window shrinks symmetrically, so it always stays centred on ``ind`` and inside the volume.

    Args:
        path_im (string): path to image
        ind (int): index (in canonical orientation) of the slice around which we will average
        slice_axis (int): Slice axis according to RAS convention
        half_width (int): number of slices taken on each side of ``ind`` (default 3)

    Returns:
        nifti: a single slice nifti object containing the average image in the image space.

    Raises:
        ValueError: if ``half_width`` is negative or ``ind`` lies outside the volume along
            ``slice_axis``.
    """
    if half_width < 0:
        raise ValueError(f"half_width must be non-negative, got {half_width}")

    image = nib.load(path_im)
    image_can = nib.as_closest_canonical(image)
    arr_can = np.asarray(image_can.dataobj)

    axis_len = arr_can.shape[slice_axis]
    if not 0 <= ind < axis_len:
        raise ValueError(f"Slice index {ind} is out of bounds for axis {slice_axis} of size {axis_len}")

    # Shrink the window symmetrically near the borders to avoid out-of-bound indexing.
    half = min(half_width, ind, axis_len - 1 - ind)
    slc = [slice(None)] * arr_can.ndim
    # Slices are half-open, hence the +1 to include slice ``ind + half``.
    slc[slice_axis] = slice(ind - half, ind + half + 1)
    mid = np.mean(arr_can[tuple(slc)], axis=slice_axis)

    # The array is already in canonical orientation, so slice_axis=2 is passed rather than
    # `slice_axis` -- presumably so that reorient_image applies no hwd->RAS permutation before
    # mapping back to the original orientation. NOTE(review): confirm against reorient_image.
    arr_pred_ref_space = imed_loader_utils.reorient_image(np.expand_dims(mid, axis=slice_axis), 2, image,
                                                          image_can).astype('float32')
    nib_pred = nib.Nifti1Image(dataobj=arr_pred_ref_space,
                               affine=image.header.get_best_affine(),
                               header=image.header.copy())

    return nib_pred
def extract_mid_slice_and_convert_coordinates_to_heatmaps(path, suffix, aim=-1):
    """
    This function takes as input a path to a dataset and generates a set of images:
    (i) mid-sagittal image and
    (ii) heatmap of disc labels associated with the mid-sagittal image.

    Example::

        ivadomed_prepare_dataset_vertebral_labeling -p path/to/bids -s _T2w -a 0

    Args:
        path (string): path to BIDS dataset from which images will be generated. Flag: ``--path``, ``-p``
        suffix (string): suffix of image that will be processed (e.g., T2w). Flag: ``--suffix``, ``-s``
        aim (int): If aim is not 0, retrieves only labels with value = aim, else create heatmap with all labels.
            Flag: ``--aim``, ``-a``. Passed unchanged to ``mask2label``.
            NOTE(review): the default is -1, not 0 -- confirm how mask2label treats -1.

    Returns:
        None. Images are saved in BIDS folder

    Raises:
        ValueError: if the disc label file of a processed subject yields no label point.
    """
    subjects = [path_object.name for path_object in Path(path).iterdir() if path_object.name != 'derivatives']

    for subject in subjects:
        path_image = Path(path, subject, 'anat', subject + suffix + '.nii.gz')
        # Entries without a matching image (e.g. top-level BIDS files) are skipped.
        if not path_image.is_file():
            continue

        path_label = Path(path, 'derivatives', 'labels', subject, 'anat',
                          subject + suffix + '_labels-disc-manual.nii.gz')
        list_points = mask2label(str(path_label), aim=aim)
        if not list_points:
            raise ValueError(f"No disc label found in {path_label} (aim={aim})")

        image_ref = nib.load(path_image)
        # Only the shape is needed: read it from the image instead of copying the voxel array.
        imsh = nib.as_closest_canonical(image_ref).shape

        # Average around the first label point's index along axis 0 (sagittal).
        mid_nifti = imed_preprocessing.get_midslice_average(str(path_image), list_points[0][0], slice_axis=0)
        nib.save(mid_nifti, Path(path, subject, 'anat', subject + suffix + '_mid.nii.gz'))

        lab = nib.load(path_label)
        lab_can = nib.as_closest_canonical(lab)
        # 2D map over the two non-sagittal axes, with a 1 at each label point's (axis 1, axis 2) position.
        label_array = np.zeros(imsh[1:])
        for point in list_points:
            label_array[point[1], point[2]] = 1

        heatmap = imed_maths.heatmap_generation(label_array, 10)
        arr_pred_ref_space = imed_loader_utils.reorient_image(np.expand_dims(heatmap, axis=0), 2, lab, lab_can)
        nib_pred = nib.Nifti1Image(arr_pred_ref_space, lab.affine)
        nib.save(nib_pred, Path(path, 'derivatives', 'labels', subject, 'anat',
                                subject + suffix + '_mid_heatmap' + str(aim) + '.nii.gz'))
def pred_to_nib(data_lst, z_lst, fname_ref, fname_out, slice_axis, debug=False, kernel_dim='2d', bin_thr=0.5,
                discard_noise=True, postprocessing=None):
    """Save the network predictions as nibabel object.

    Based on the `fname_ref` image, it creates a nibabel object from the Network predictions
    (`data_lst`), reoriented to the orientation of `fname_ref`.

    Args:
        data_lst (list of np arrays): Predictions, either 2D slices either 3D patches.
        z_lst (list of ints): Slice indexes to reconstruct a 3D volume for 2D slices. Slices of the
            reference volume that are absent from `z_lst` are filled with zeros. If an index
            appears several times, its first occurrence is used.
        fname_ref (str): Filename of the input image: its affine is copied to the output nibabel object.
        fname_out (str): If not None, then the generated nibabel object is saved with this filename.
        slice_axis (int): Indicates the axis used for the 2D slice extraction:
            Sagittal: 0, Coronal: 1, Axial: 2.
        debug (bool): If True, extended verbosity and intermediate outputs.
        kernel_dim (str): Indicates whether the predictions were done on 2D or 3D patches. Choices: '2d', '3d'.
        bin_thr (float): If non-negative, then the segmentation is binarized with this given threshold.
            Otherwise, a soft segmentation is output.
        discard_noise (bool): If True (and `bin_thr` is negative), predictions that are lower than or
            equal to 0.01 are set to zero.
        postprocessing (dict): Contains postprocessing steps to be applied.

    Returns:
        NibabelObject: Object containing the Network prediction.
    """
    # Load reference nibabel object
    nib_ref = nib.load(fname_ref)
    nib_ref_can = nib.as_closest_canonical(nib_ref)

    if kernel_dim == '2d':
        # Position of each slice index in data_lst. setdefault keeps the first occurrence
        # (same as list.index) without an O(n) search per slice.
        position_of_z = {}
        for position, z in enumerate(z_lst):
            position_of_z.setdefault(z, position)

        # complete missing z with zeros
        tmp_lst = [data_lst[position_of_z[z]] if z in position_of_z else np.zeros(data_lst[0].shape)
                   for z in range(nib_ref_can.header.get_data_shape()[slice_axis])]

        if debug:
            print("Len {}".format(len(tmp_lst)))
            for arr in tmp_lst:
                print("Shape element lst {}".format(arr.shape))

        # create data and stack on depth dimension
        arr_pred_ref_space = np.stack(tmp_lst, axis=-1)
    else:
        arr_pred_ref_space = data_lst[0]

    if len(arr_pred_ref_space.shape) == 4:
        # Multi-channel prediction: reorient each channel (first axis) separately.
        oriented_volumes = [
            imed_loader_utils.reorient_image(arr_pred_ref_space[i, ], slice_axis, nib_ref, nib_ref_can)
            for i in range(arr_pred_ref_space.shape[0])
        ]
        # transpose to locate the channel dimension at the end to properly see image on viewer
        arr_pred_ref_space = np.asarray(oriented_volumes).transpose((1, 2, 3, 0))
    else:
        arr_pred_ref_space = imed_loader_utils.reorient_image(arr_pred_ref_space, slice_axis, nib_ref, nib_ref_can)

    if bin_thr >= 0:
        arr_pred_ref_space = imed_postpro.threshold_predictions(arr_pred_ref_space, thr=bin_thr)
    elif discard_noise:  # discard noise
        arr_pred_ref_space[arr_pred_ref_space <= 1e-2] = 0

    if postprocessing:
        fname_prefix = fname_out.split("_pred.nii.gz")[0] if fname_out is not None else None
        postpro = imed_postpro.Postprocessing(postprocessing,
                                              arr_pred_ref_space,
                                              nib_ref.header['pixdim'][1:4],
                                              fname_prefix)
        arr_pred_ref_space = postpro.apply()

    # create nibabel object (only the reference affine is reused, not its header)
    nib_pred = nib.Nifti1Image(arr_pred_ref_space, nib_ref.affine)

    # save as nifti file
    if fname_out is not None:
        nib.save(nib_pred, fname_out)

    return nib_pred
def test_image_orientation(download_data_testing_test_files, loader_parameters):
    """Check that transforms, undo transforms and 3D reconstruction preserve orientation.

    The test covers both 2D and 3D datasets and every slice axis (0, 1, 2). For each case, the
    ground truth of ``sub-unf01_T1w.nii.gz`` is loaded through the data loader with the
    Resample, CenterCrop and NormalizeInstance transforms. The transforms are then undone, the
    volume is rebuilt and thresholded. The result must reach a Dice score of at least 0.8
    against the ground truth in three orientations:

    - hwd, as returned by ``SegmentationPair.get_pair_data``;
    - RAS, the closest canonical orientation;
    - the original file orientation.

    Args:
        download_data_testing_test_files: pytest fixture. It is not referenced in the body;
            presumably it is requested for its side effect of providing the test data -- TODO confirm.
        loader_parameters (dict): loader configuration; the keys "contrast_params",
            "target_suffix" and "roi_params" are read.
    """
    device = torch.device("cuda:" + str(GPU_ID) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        logger.info(f"Using GPU ID {device}")

    bids_df = BidsDataframe(loader_parameters, __tmp_dir__, derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    # NOTE(review): roi_params is read but not used below.
    roi_params = loader_parameters["roi_params"]
    train_lst = ['sub-unf01_T1w.nii.gz']

    training_transform_dict = {
        "Resample": {
            "wspace": 1.5,
            "hspace": 1,
            "dspace": 3
        },
        "CenterCrop": {
            "size": [176, 128, 160]
        },
        "NormalizeInstance": {
            "applied_to": ['im']
        }
    }

    # The undo transform is applied below to map loader samples back to the original geometry.
    tranform_lst, training_undo_transform = imed_transforms.prepare_transforms(training_transform_dict)

    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = BidsDataset(bids_df=bids_df,
                                 subject_file_lst=train_lst,
                                 target_suffix=target_suffix,
                                 contrast_params=contrast_params,
                                 model_params=model_params,
                                 metadata_choice=False,
                                 slice_axis=slice_axis,
                                 transform=tranform_lst,
                                 multichannel=False)
                ds.load_filenames()
            else:
                ds = Bids3DDataset(bids_df=bids_df,
                                   subject_file_lst=train_lst,
                                   target_suffix=target_suffix,
                                   model_params=model_params,
                                   contrast_params=contrast_params,
                                   metadata_choice=False,
                                   slice_axis=slice_axis,
                                   transform=tranform_lst,
                                   multichannel=False)

            # shuffle=True: 2D slices arrive in random order, so their slice indexes are tracked
            # in z_tmp_lst and used below to restack the volume in order.
            loader = DataLoader(ds, batch_size=1, shuffle=True, pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[0]
            segpair = SegmentationPair(input_filename, gt_filename, metadata=metadata, slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get image with original, ras and hwd orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                for smp_idx in range(len(batch['gt'])):
                    # undo transformations
                    if dim == '2d':
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                        # add new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))

                    else:
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    # 2D: rebuild and check only once the last batch has been collected.
                    # 3D: every sample is already a full volume, so check it immediately.
                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            # Restack slices in index order along the last axis
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
                                tmp_lst.append(pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some difference are generated due to transform and undo transform
                        # (e.g. Resample interpolation)
                        assert imed_metrics.dice_score(input_hwd_2, input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2, input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2, input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []