def __getitem__(self, idx):
    """
    Extract the CT slice and corresponding mask specified by idx.
    ----------
    INPUT
        |---- idx (int) the sample index in self.data_df.
    OUTPUT
        |---- ct_slice (torch.tensor) the CT image with dimension (1 x H x W).
        |---- mask (torch.tensor) the segmentation mask with dimension (1 x H x W).
        |---- patient_nbr (torch.tensor) the patient id as a single value.
        |---- slice_nbr (torch.tensor) the slice number as a single value.
    """
    # fetch the dataframe row once instead of re-indexing it for every field
    row = self.data_df.iloc[idx]
    # load image (optionally windowed to [0, 1])
    ct_slice = io.imread(self.data_path + row.CT_fn)
    if self.window:
        ct_slice = window_ct(ct_slice, win_center=self.window[0],
                             win_width=self.window[1], out_range=(0, 1))
    # load mask if one ('None' string marks slices without mask), else make a
    # blank array shaped like the (possibly windowed) slice
    if row.mask_fn == 'None':
        mask = np.zeros_like(ct_slice)
    else:
        mask = io.imread(self.data_path + row.mask_fn)
    # patient id and slice number as scalar tensors
    patient_nbr = torch.tensor(row.PatientNumber)
    slice_nbr = torch.tensor(row.SliceNumber)
    # Apply the transform : Data Augmentation + image formatting
    ct_slice, mask = self.transform(ct_slice, mask)
    return ct_slice, mask, patient_nbr, slice_nbr
def __getitem__(self, idx):
    """
    Extract the CT specified by idx, optionally corrupted by an artificial anomaly.
    ----------
    INPUT
        |---- idx (int) the sample index in self.data_df.
    OUTPUT
        |---- im (torch.tensor) the CT image with dimension (1 x H x W).
        |---- label (torch.tensor) the Hemorrhage label of the sample, set to 1
        |           when an artificial anomaly has been drawn on it.
        |---- idx (torch.tensor) the sample index.
    """
    row = self.data_df.iloc[idx]
    # load dicom and recover the CT pixel values with the rescale slope/intercept
    dcm_im = pydicom.dcmread(self.data_path + row.filename)
    im = (dcm_im.pixel_array * float(dcm_im.RescaleSlope)
          + float(dcm_im.RescaleIntercept))
    # Window the CT-scan
    if self.window:
        im = window_ct(im, win_center=self.window[0], win_width=self.window[1],
                       out_range=(0, 1))
    # transform image
    im = self.transform(im)
    label = row.Hemorrhage
    # With probability anomaly_proba, a healthy slice (label 0) receives drawn
    # ellipses and becomes a positive sample. The random draw happens before the
    # label check, so the RNG is consumed for every sample when enabled.
    if self.artificial_anomaly and (np.random.rand() < self.anomaly_proba) and (label == 0):
        anomalies = tf.ToTorchTensor()(self.draw_ellipses(
            (im.shape[1], im.shape[2]), **self.drawing_params))
        # paste the ellipses wherever they are strictly positive
        im = torch.where(anomalies > 0, anomalies, im)
        label = 1
    return im, torch.tensor(label), torch.tensor(idx)
def __getitem__(self, idx):
    """
    Return the CT slice stored at position idx with a random inpainting mask.
    ----------
    INPUT
        |---- idx (int) the sample index in self.data_df.
    OUTPUT
        |---- im (torch.tensor) the CT image with dimension (1 x H x W).
        |---- mask (torch.tensor) the inpainting mask with dimension (1 x H x W).
    """
    record = self.data_df.iloc[idx]
    dicom = pydicom.dcmread(self.data_path + record.filename)
    slope = float(dicom.RescaleSlope)
    intercept = float(dicom.RescaleIntercept)
    # convert stored pixel values with the DICOM rescale parameters
    ct = dicom.pixel_array * slope + intercept
    # optional windowing to [0, 1]
    if self.window:
        ct = window_ct(ct, win_center=self.window[0], win_width=self.window[1],
                       out_range=(0, 1))
    ct = self.transform(ct)
    # draw a free-form mask matching the spatial size of the transformed image
    height, width = ct.shape[1], ct.shape[2]
    inpaint_mask = self.random_ff_mask((height, width))
    return ct, tf.ToTorchTensor()(inpaint_mask)
def __getitem__(self, idx):
    """
    Return the image, its mask and the sample index stored at position idx.
    ----------
    INPUT
        |---- idx (int) the sample index in self.data_df.
    OUTPUT
        |---- im (torch.tensor) the image with dimension (1 x H x W).
        |---- mask (torch.tensor) the mask with dimension (1 x H x W).
        |---- idx (torch.tensor) the data index in the data_df.
    """
    record = self.data_df.iloc[idx]
    # read the image and its mask from disk (paths relative to data_path)
    image = io.imread(os.path.join(self.data_path, record.im_fn))
    mask = io.imread(os.path.join(self.data_path, record.mask_fn))
    # optional CT windowing rescaled to [0, 1]
    if self.window:
        image = window_ct(image, win_center=self.window[0],
                          win_width=self.window[1], out_range=(0, 1))
    # image and mask go through the transform together (presumably so spatial
    # augmentations are shared)
    image, mask = self.transform(image, mask)
    return image, mask, torch.tensor(idx)
def main(input_data_path, output_data_path, window):
    """
    Convert the Volumetric CT data and mask (in NIfTI format) to a dataset of
    2D images in tif and masks in bitmap for the brain extraction.
    ----------
    INPUT
        |---- input_data_path (str) folder containing info.csv,
        |           ct_scans/<id>.nii and masks/<id>.nii.gz.
        |---- output_data_path (str) folder where slices and csv are written.
        |---- window (tuple (center, width) or None) optional CT windowing.
    OUTPUT
        |---- None. Writes <id:03>/ct/<n>.tif for every slice,
        |     <id:03>/mask/<n>_Seg.bmp for slices with some brain,
        |     slice_info.csv and volume_info.csv in output_data_path.
    """
    # open data info dataframe
    info_df = pd.read_csv(os.path.join(input_data_path, 'info.csv'), index_col=0)
    # make output directory (also creates missing parents, tolerates reruns)
    os.makedirs(output_data_path, exist_ok=True)

    n_volumes = len(info_df.id)
    output_info = []
    for n, vol_id in enumerate(info_df.id.values):
        # read nii volume
        ct_nii = nib.load(os.path.join(input_data_path, f'ct_scans/{vol_id}.nii'))
        mask_nii = nib.load(os.path.join(input_data_path, f'masks/{vol_id}.nii.gz'))
        ct_vol = ct_nii.get_fdata()
        mask_vol = skimage.img_as_bool(mask_nii.get_fdata())
        # rotate 90° counter clockwise for head pointing upward
        ct_vol = np.rot90(ct_vol, axes=(0, 1))
        mask_vol = np.rot90(mask_vol, axes=(0, 1))
        # window the ct volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol, win_center=window[0], win_width=window[1],
                               out_range=(0, 1))
        if mask_vol.shape != ct_vol.shape:
            # only a warning: slices are still exported below
            print(f'>>> Warning! The ct volume of patient {vol_id} does not have '
                  f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})')
        # make volume directories
        os.makedirs(os.path.join(output_data_path, f'{vol_id:03}/ct/'), exist_ok=True)
        os.makedirs(os.path.join(output_data_path, f'{vol_id:03}/mask/'), exist_ok=True)

        n_slices = ct_vol.shape[2]
        for slice_idx in range(n_slices):
            slice_nbr = slice_idx + 1  # slices are numbered from 1 on disk
            ct_slice = ct_vol[:, :, slice_idx]
            ct_slice_fn = f'{vol_id:03}/ct/{slice_nbr}.tif'
            # save CT slice
            skimage.io.imsave(os.path.join(output_data_path, ct_slice_fn), ct_slice,
                              check_contrast=False)
            is_low = bool(skimage.exposure.is_low_contrast(ct_slice))
            # save mask only if some brain on slice
            mask_slice = mask_vol[:, :, slice_idx]
            if np.any(mask_slice):
                mask_slice_fn = f'{vol_id:03}/mask/{slice_nbr}_Seg.bmp'
                skimage.io.imsave(os.path.join(output_data_path, mask_slice_fn),
                                  skimage.img_as_ubyte(mask_slice), check_contrast=False)
            else:
                mask_slice_fn = 'None'
            output_info.append({'volume': vol_id, 'slice': slice_nbr,
                                'ct_fn': ct_slice_fn, 'mask_fn': mask_slice_fn,
                                'low_contrast_ct': is_low})
            print_progessbar(slice_idx, n_slices,
                             Name=f'Volume {vol_id:03} {n+1:03}/{n_volumes:03}',
                             Size=20, erase=False)

    # Make dataframe of outputs and save it
    output_info_df = pd.DataFrame(output_info)
    output_info_df.to_csv(os.path.join(output_data_path, 'slice_info.csv'))
    print('>>> Slice informations saved at ' + os.path.join(output_data_path, 'slice_info.csv'))
    # save volume df
    info_df.to_csv(os.path.join(output_data_path, 'volume_info.csv'))
    print('>>> Volume informations saved at ' + os.path.join(output_data_path, 'volume_info.csv'))
def __getitem__(self, idx):
    """
    Extract the CT specified by idx and format it according to self.mode.
    ----------
    INPUT
        |---- idx (int) the sample index in self.data_df.
    OUTPUT (depends on self.mode; images go through self.toTensor, idx is a tensor)
        |---- 'standard': (im, idx).
        |---- 'context_restoration': (im, swapped_im, idx) where swapped_im is
        |           the transformed image passed through self.swap_tranform.
        |---- 'contrastive': (im1, im2, idx), two separate passes of
        |           self.transform then self.contrastive_transform on the CT.
        |---- 'binary_classification': (im, label, idx) with label the
        |           Hemorrhage value of the sample.
        |---- 'multi_classification': (im, labels, idx) with labels the values
        |           of the columns listed in self.class_name.
    RAISE
        |---- ValueError if self.mode is none of the modes above.
    """
    row = self.data_df.iloc[idx]
    # load dicom and recover the CT pixel values with the rescale slope/intercept
    dcm_im = pydicom.dcmread(self.data_path + row.filename)
    im = (dcm_im.pixel_array * float(dcm_im.RescaleSlope)
          + float(dcm_im.RescaleIntercept))
    # Window the CT-scan
    if self.window:
        im = window_ct(im, win_center=self.window[0], win_width=self.window[1],
                       out_range=(0, 1))

    if self.mode == 'standard':
        im = self.transform(im)
        return self.toTensor(im), torch.tensor(idx)
    elif self.mode == 'context_restoration':
        # generate a corrupted version of the transformed image
        im = self.transform(im)
        swapped_im = self.swap_tranform(im)
        return self.toTensor(im), self.toTensor(swapped_im), torch.tensor(idx)
    elif self.mode == 'contrastive':
        # augment the image twice
        im1 = self.contrastive_transform(self.transform(im))
        im2 = self.contrastive_transform(self.transform(im))
        return self.toTensor(im1), self.toTensor(im2), torch.tensor(idx)
    elif self.mode == 'binary_classification':
        im = self.transform(im)
        label = row.Hemorrhage
        return self.toTensor(im), torch.tensor(label), torch.tensor(idx)
    elif self.mode == 'multi_classification':
        im = self.transform(im)
        label = [row[name] for name in self.class_name]
        return self.toTensor(im), torch.tensor(label), torch.tensor(idx)
    # previously an unknown mode silently returned None
    raise ValueError(f"Unknown dataset mode '{self.mode}'. Expected one of "
                     "'standard', 'context_restoration', 'contrastive', "
                     "'binary_classification' or 'multi_classification'.")
def __getitem__(self, idx):
    """
    Extract the CT stacked with its attention map, the corresponding ground
    truth mask, volume id and slice number specified by idx.
    ----------
    INPUT
        |---- idx (int) the sample index in self.data_df.
    OUTPUT
        |---- input (torch.tensor) the CT image stacked with the attention map
        |           with dimension (2 x H x W).
        |---- mask (torch.tensor) the segmentation mask with dimension (1 x H x W).
        |---- patient_nbr (torch.tensor) the patient id as a single value.
        |---- slice_nbr (torch.tensor) the slice number as a single value.
    """
    row = self.data_df.iloc[idx]
    # load image (optionally windowed to [0, 1])
    ct_slice = io.imread(self.data_path + row.ct_fn)
    if self.window:
        ct_slice = window_ct(ct_slice, win_center=self.window[0],
                             win_width=self.window[1], out_range=(0, 1))
    # load attention map ('None' string -> blank map) resized bilinearly to the slice size
    if row.attention_fn == 'None':
        attention_map = np.zeros_like(ct_slice)
    else:
        attention_map = skimage.img_as_float(io.imread(self.data_path + row.attention_fn))
        attention_map = skimage.transform.resize(attention_map, ct_slice.shape[:2],
                                                 order=1, preserve_range=True)
    # stack on a last channel axis; assumes 2D slices, giving (H x W x 2) before
    # the transform — TODO confirm the transform moves channels first
    stacked = np.stack([ct_slice, attention_map], axis=2)
    # load mask if one, else make a blank array
    if row.mask_fn == 'None':
        mask = np.zeros_like(ct_slice)
    else:
        mask = io.imread(self.data_path + row.mask_fn)
    patient_nbr = torch.tensor(row.id)
    slice_nbr = torch.tensor(row.slice)
    # Apply the transform : Data Augmentation + image formatting
    stacked, mask = self.transform(stacked, mask)
    return stacked, mask, patient_nbr, slice_nbr
def __getitem__(self, idx):
    """
    Get the CT-volume, its segmentation mask and patient number for idx.
    ----------
    INPUT
        |---- idx the row label in self.data_df (looked up with .loc).
    OUTPUT
        |---- ct_vol (torch.Tensor) the windowed (if self.window), resampled and
        |           transformed CT-volume.
        |---- mask (torch.Tensor) the resampled and transformed mask as bool.
        |---- pID (torch.Tensor) the PatientNumber of the sample.
    """
    # load data
    ct_nii = nib.load(self.data_path + self.data_df.loc[idx, 'CT_fn'])
    mask_nii = nib.load(self.data_path + self.data_df.loc[idx, 'mask_fn'])
    pID = torch.tensor(self.data_df.loc[idx, 'PatientNumber'])
    # get volumes rotated 90° in the (0, 1) plane
    ct_vol = np.rot90(ct_nii.get_fdata(), axes=(0, 1))
    mask = np.rot90(mask_nii.get_fdata(), axes=(0, 1))
    # voxel physical dimensions from the NIfTI header
    pix_dim = ct_nii.header['pixdim'][1:4]
    # window CT-scan for soft tissues; skipped when no window is configured
    # (previously crashed on a None window), as in the 2D datasets
    if self.window:
        ct_vol = window_ct(ct_vol, win_center=self.window[0],
                           win_width=self.window[1], out_range=(0, 1))
    # resample vol and mask to the requested voxel size
    ct_vol = resample_ct(ct_vol, pix_dim, out_pixel_dim=self.resampling_dim,
                         preserve_range=True, order=self.resampling_order)
    # order=0 for the mask (presumably nearest-neighbour) so labels are not blended
    mask = resample_ct(mask, pix_dim, out_pixel_dim=self.resampling_dim,
                       preserve_range=True, order=0)
    ct_vol, mask = self.transform(ct_vol, mask)
    return ct_vol, mask.bool(), pID
def segement_volume(self, vol, save_fn=None, window=None, input_size=(256, 256),
                    return_pred=False):
    """
    Segment each slice of the passed Nifti volume and save the results as a
    Nifti volume.
    ----------
    INPUT
        |---- vol (nibabel.nifti1.Nifti1Pair) the nibabel volume with metadata to segment.
        |---- save_fn (str) where to save the segmentation (not saved if falsy).
        |---- window (tuple (center, width)) the windowing to apply to the ct-scan
        |           (skipped if falsy).
        |---- input_size (tuple (h, w)) the input size for the network.
        |---- return_pred (bool) whether to return the volume of prediction.
    OUTPUT
        |---- (mask_vol) (nibabel.nifti1.Nifti1Pair) the prediction volume with
        |           vol.affine, returned only if return_pred (None otherwise).
    """
    pred_list = []
    # 90° counterclockwise rotation: (dim1 x dim2 x S) -> (dim2 x dim1 x S)
    vol_data = np.rot90(vol.get_fdata(), axes=(0, 1))
    if window:
        vol_data = window_ct(vol_data, win_center=window[0], win_width=window[1],
                             out_range=(0, 1))
    transform = tf.Compose(tf.Resize(H=input_size[0], W=input_size[1]),
                           tf.ToTorchTensor())
    self.unet.eval()
    self.unet.to(self.device)
    n_slices = vol_data.shape[2]
    with torch.no_grad():
        for s in range(0, n_slices, self.batch_size):
            # up to batch_size consecutive slices; permute brings the slice axis
            # first (assumes ToTorchTensor returns (C x H x W x B) — TODO confirm)
            net_input = transform(vol_data[:, :, s:s + self.batch_size]).to(
                self.device).float().permute(3, 0, 1, 2)
            pred = self.unet(net_input)
            # binarize at 0.5, keeping the network output dtype and device
            pred = (pred >= 0.5).to(pred.dtype)
            # store pred as (H x W x B) uint8 in {0, 255}
            pred_list.append(
                pred.squeeze(dim=1).permute(1, 2, 0).cpu().numpy().astype(np.uint8) * 255)
            if self.print_progress:
                print_progessbar(s + pred.shape[0] - 1, Max=n_slices, Name='Slice',
                                 Size=20, erase=True)
    # make the prediction volume
    vol_pred = np.concatenate(pred_list, axis=2)
    # Resize in the rotated frame, i.e. to (dim2, dim1), then rotate 90°
    # clockwise so the result is (dim1, dim2, S) like vol and matches
    # vol.affine. Resizing to (dim1, dim2) first would swap H and W for
    # non-square slices.
    # NOTE(review): resize() without preserve_range rescales the uint8 0/255
    # mask to [0, 1], so the saved volume holds 0/1 values — confirm intended.
    rotated_hw = (vol.header['dim'][2], vol.header['dim'][1])
    vol_pred = np.rot90(skimage.transform.resize(vol_pred, rotated_hw, order=0),
                        axes=(1, 0))
    # make Nifti and save it
    vol_pred_nii = nib.Nifti1Pair(vol_pred.astype(np.uint8), vol.affine)
    if save_fn:
        nib.save(vol_pred_nii, save_fn)
    # return Nifti prediction
    if return_pred:
        return vol_pred_nii
def main(input_data_path, output_data_path, window):
    """
    Convert the Volumetric CT data and mask (in NIfTI format) to a dataset of
    2D images in tif and masks in bitmap.
    ----------
    INPUT
        |---- input_data_path (str) input folder; paths are built by string
        |           concatenation so it must end with a separator. Contains
        |           hemorrhage_diagnosis_raw_ct.csv, Patient_demographics.csv,
        |           ct_scans/<id:03>.nii and masks/<id:03>.nii.
        |---- output_data_path (str) output folder, also ending with a separator.
        |---- window (tuple (center, width) or None) optional CT windowing.
    OUTPUT
        |---- None. Writes Patient_CT/<id:03>/<n>.tif for every slice,
        |     Patient_CT/<id:03>/<n>_ICH_Seg.bmp for slices with a non-empty
        |     mask, ct_info.csv and patient_info.csv in output_data_path.
    """
    # open data info dataframe
    info_df = pd.read_csv(input_data_path + 'hemorrhage_diagnosis_raw_ct.csv')
    # replace No_Hemorrhage by its complement Hemorrhage
    info_df['Hemorrhage'] = 1 - info_df.No_Hemorrhage
    info_df.drop(columns='No_Hemorrhage', inplace=True)
    # open patient info dataframe (header on the 2nd line, last 2 lines skipped)
    patient_df = pd.read_csv(input_data_path + 'Patient_demographics.csv', header=1,
                             skipfooter=2, engine='python') \
        .rename(columns={'Unnamed: 0': 'PatientNumber', 'Unnamed: 1': 'Age',
                         'Unnamed: 2': 'Gender', 'Unnamed: 8': 'Fracture',
                         'Unnamed: 9': 'Note'})
    flag_cols = patient_df.columns[3:9]
    patient_df[flag_cols] = patient_df[flag_cols].fillna(0).astype(int)
    # add column Hemorrhage (any ICH) as the max over columns 3 to 7
    patient_df['Hemorrhage'] = patient_df[patient_df.columns[3:8]].max(axis=1)
    # make output directories (tolerates reruns)
    os.makedirs(output_data_path, exist_ok=True)
    os.makedirs(output_data_path + 'Patient_CT/', exist_ok=True)

    # computed once rather than on every slice for the progress bar
    patient_ids = info_df.PatientNumber.unique()
    n_patients = len(patient_ids)
    output_info = []
    for n, patient_id in enumerate(patient_ids):
        # read nii volume
        ct_nii = nib.load(input_data_path + f'ct_scans/{patient_id:03}.nii')
        mask_nii = nib.load(input_data_path + f'masks/{patient_id:03}.nii')
        ct_vol = ct_nii.get_fdata()
        mask_vol = skimage.img_as_bool(mask_nii.get_fdata())
        # rotate 90° counter clockwise for head pointing upward
        ct_vol = np.rot90(ct_vol, axes=(0, 1))
        mask_vol = np.rot90(mask_vol, axes=(0, 1))
        # window the ct volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol, win_center=window[0], win_width=window[1],
                               out_range=(0, 1))
        if mask_vol.shape != ct_vol.shape:
            # only a warning: slices are still exported below
            print(f'>>> Warning! The ct volume of patient {patient_id} does not have '
                  f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})')
        # make patient directory
        os.makedirs(output_data_path + f'Patient_CT/{patient_id:03}/', exist_ok=True)

        n_slices = ct_vol.shape[2]
        for slice_idx in range(n_slices):
            slice_nbr = slice_idx + 1  # slices are numbered from 1
            ct_slice = ct_vol[:, :, slice_idx]
            ct_slice_fn = f'Patient_CT/{patient_id:03}/{slice_nbr}.tif'
            # save CT slice
            skimage.io.imsave(output_data_path + ct_slice_fn, ct_slice,
                              check_contrast=False)
            is_low = bool(skimage.exposure.is_low_contrast(ct_slice))
            # save mask only if some positive ICH
            mask_slice = mask_vol[:, :, slice_idx]
            if np.any(mask_slice):
                mask_slice_fn = f'Patient_CT/{patient_id:03}/{slice_nbr}_ICH_Seg.bmp'
                skimage.io.imsave(output_data_path + mask_slice_fn,
                                  skimage.img_as_ubyte(mask_slice), check_contrast=False)
            else:
                mask_slice_fn = 'None'
            output_info.append({
                'PatientNumber': patient_id,
                'SliceNumber': slice_nbr,
                'CT_fn': ct_slice_fn,
                'mask_fn': mask_slice_fn,
                'low_contrast_CT': is_low
            })
            print_progessbar(slice_idx, n_slices,
                             Name=f'Patient {patient_id:03} {n+1:03}/{n_patients:03}',
                             Size=20, erase=False)

    # Merge slice outputs with input info (inner join keeps exported slices only)
    output_info_df = pd.DataFrame(output_info)
    info_df = pd.merge(info_df, output_info_df, how='inner',
                       on=['PatientNumber', 'SliceNumber'])
    # save df
    info_df.to_csv(output_data_path + 'ct_info.csv')
    print('>>> Data informations saved at ' + output_data_path + 'ct_info.csv')
    # save patient df
    patient_df.to_csv(output_data_path + 'patient_info.csv')
    print('>>> Patient informations saved at ' + output_data_path + 'patient_info.csv')
def analyse_supervised_exp(exp_folder, data_path, n_fold, config_folder=None,
                           save_fn='results_overview.pdf', is_brain_exp=False):
    """
    Generate a summary figure of the supervised ICH segmentation experiment.
    ----------
    INPUT
        |---- exp_folder (str) experiment folder with Fold_*/outputs.json,
        |           all_volume_prediction.csv and Fold_<i>/pred/ predictions.
        |---- data_path (str) root of the 2D slice dataset (CT tif / mask bmp).
        |---- n_fold (int) number of folds whose slice_prediction_scores.csv is read.
        |---- config_folder (str) folder containing config.json (defaults to
        |           exp_folder); its data win_center/win_width window the CTs.
        |---- save_fn (str) path where the figure is saved.
        |---- is_brain_exp (bool) whether this is the brain extraction experiment
        |           (changes labels and the dataset file layout).
    OUTPUT
        |---- None. The figure is saved at save_fn.
    """
    if config_folder is None:
        config_folder = exp_folder

    def h_padcat(arr1, arr2):
        """Concatenate two 2D arrays column-wise, NaN-padding the shorter one."""
        out_len = max([arr1.shape[0], arr2.shape[0]])
        arr1_pad = np.pad(arr1, ((0, out_len - arr1.shape[0]), (0, 0)),
                          constant_values=np.nan)
        arr2_pad = np.pad(arr2, ((0, out_len - arr2.shape[0]), (0, 0)),
                          constant_values=np.nan)
        return np.concatenate([arr1_pad, arr2_pad], axis=1)

    def dice_by_label(df):
        """Dice columns [all | label == 1 | label == 0] of df, NaN-padded."""
        return h_padcat(h_padcat(df[['Dice']].values,
                                 df.loc[df.label == 1, ['Dice']].values),
                        df.loc[df.label == 0, ['Dice']].values)

    ########## get data
    # training evolution of each fold (sorted for a deterministic fold order)
    loss_list = []
    for train_stat_fn in sorted(glob.glob(os.path.join(exp_folder, 'Fold_*/outputs.json'))):
        with open(train_stat_fn, 'r') as fh:
            loss_list.append(np.array(json.load(fh)['train']['evolution']))
    # One (epochs x 1+n_folds) array per plotted series (evolution columns 1-3),
    # the first column being the epoch number. Guarded: np.stack fails on an
    # empty list when no outputs.json is found.
    data_evo = []
    if loss_list:
        evolution = np.stack(loss_list, axis=2)[:, [1, 2, 3], :]
        epochs = np.expand_dims(np.arange(1, evolution.shape[0] + 1), axis=1)
        data_evo = [np.concatenate([epochs, evolution[:, i, :]], axis=1)
                    for i in range(evolution.shape[1])]
    # volume-level performances (first csv column is the saved index)
    results_df = pd.read_csv(os.path.join(exp_folder, 'all_volume_prediction.csv'))
    results_df = results_df.drop(columns=results_df.columns[0])
    with open(os.path.join(config_folder, 'config.json'), 'r') as f:
        cfg = json.load(f)
    # slice-level performances of every fold
    df_list = []
    for i in range(n_fold):
        df_tmp = pd.read_csv(os.path.join(
            exp_folder, f'Fold_{i+1}/pred/slice_prediction_scores.csv'))
        df_tmp['Fold'] = i + 1
        df_list.append(df_tmp)
    slice_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    slice_df = slice_df.drop(columns=slice_df.columns[0])

    ########## PLOT
    fontsize = 12
    n_samp = 10
    fig = plt.figure(figsize=(15, 12))
    gs = fig.add_gridspec(nrows=7, ncols=n_samp,
                          height_ratios=[0.1, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15],
                          hspace=0.4)

    # Loss Evolution Plot
    ax_evo = fig.add_subplot(gs[:2, :4])
    colors_evo = (['black', 'tomato', 'dodgerblue'] if not is_brain_exp
                  else ['black', 'tomato', 'tomato'])
    # NOTE(review): the first (train loss) series is labelled 'All' for the
    # brain experiment — confirm this label is intended.
    serie_names_evo = (['Train Loss', 'Dice (all)', 'Dice (ICH)'] if not is_brain_exp
                       else ['All', 'Dice (all)', 'Dice (brain)'])
    if data_evo:
        curve_std(data_evo, serie_names_evo, colors=colors_evo, ax=ax_evo, lw=1,
                  CI_alpha=0.05, rep_alpha=0.25, plot_rep=True, plot_mean=True,
                  plot_CI=True, legend=True,
                  legend_kwargs=dict(loc='upper left', ncol=3, frameon=False,
                                     framealpha=0.0, fontsize=fontsize,
                                     bbox_to_anchor=(0.0, -0.3),
                                     bbox_transform=ax_evo.transAxes))
        ax_evo.set_xlim([1, data_evo[0].shape[0]])
    ax_evo.set_xlabel('Epoch [-]', fontsize=fontsize)
    ax_evo.set_ylabel('Dice Loss [-] ; Dice Coeff. [-]', fontsize=fontsize)
    ax_evo.set_title('Training evolution', fontsize=fontsize, fontweight='bold', loc='left')
    ax_evo.set_ylim([0, 1])
    ax_evo.tick_params(axis='both', labelsize=fontsize)
    ax_evo.spines['top'].set_visible(False)
    ax_evo.spines['right'].set_visible(False)

    # Conf Mat BarPlot, split on two axes (broken y-axis) because TN is much larger
    ax_cm = fig.add_subplot(gs[1, 5:7])
    ax_cm_bis = fig.add_subplot(gs[0, 5:7])
    cm_cols = ['TP', 'TN', 'FP', 'FN']
    data_cm = [results_df[cm_cols].values,
               results_df.loc[results_df.label == 1, cm_cols].values,
               results_df.loc[results_df.label == 0, cm_cols].values]
    serie_names_cm = (['All', 'ICH only', 'Non-ICH only'] if not is_brain_exp
                      else ['All', 'brain only', 'No_brain only'])
    cm_plot_kwargs = dict(serie_names=serie_names_cm, group_names=cm_cols,
                          colors=['tomato', 'dodgerblue', 'cornflowerblue'],
                          fontsize=fontsize, jitter=True, jitter_color='gray',
                          jitter_alpha=0.25, display_val=False)
    metric_barplot(data_cm, ax=ax_cm, legend=True,
                   legend_kwargs=dict(loc='upper left', ncol=1, frameon=False,
                                      framealpha=0.0, fontsize=fontsize,
                                      bbox_to_anchor=(0.0, -0.4),
                                      bbox_transform=ax_cm.transAxes),
                   **cm_plot_kwargs)
    tp_fp_fn = results_df[['TP', 'FP', 'FN']].values
    ax_cm.set_ylim([0, (tp_fp_fn.mean(axis=0) + 2.2 * tp_fp_fn.std(axis=0)).max()])
    ax_cm.set_ylabel('Count [-]', fontsize=fontsize)
    ax_cm.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1e'))
    # duplicate plot for long tail bar of True negative
    metric_barplot(data_cm, ax=ax_cm_bis, legend=False, legend_kwargs=None,
                   **cm_plot_kwargs)
    tn = results_df['TN'].values
    ax_cm_bis.set_ylim(bottom=tn.mean() - 2.2 * tn.std())
    ax_cm_bis.set_title('Volumetric Classification', fontsize=fontsize,
                        fontweight='bold', loc='left')
    ax_cm_bis.spines['bottom'].set_visible(False)
    ax_cm_bis.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1e'))
    ax_cm_bis.set_xticklabels([])
    ax_cm_bis.set_xticks([])
    # small diagonal marks drawn at the break between the two axes
    d = .015
    ax_cm.plot((-d, +d), (1 - d, 1 + d), transform=ax_cm.transAxes, color='k', clip_on=False)
    ax_cm_bis.plot((-d, +d), (-d, +d), transform=ax_cm_bis.transAxes, color='k', clip_on=False)

    # Dice barplot (volume and slice level)
    ax_dice = fig.add_subplot(gs[0:2, 8:])
    data_dice = [dice_by_label(results_df), dice_by_label(slice_df)]
    group_names_dice = (['All', 'ICH only', 'Non-ICH only'] if not is_brain_exp
                        else ['All', 'brain only', 'No_brain only'])
    metric_barplot(data_dice, serie_names=['Volume Dice', 'Slice Dice'],
                   group_names=group_names_dice,
                   colors=['xkcd:pumpkin orange', 'xkcd:peach'], ax=ax_dice,
                   fontsize=fontsize, jitter=True, jitter_color='gray',
                   jitter_alpha=0.25, legend=True,
                   legend_kwargs=dict(loc='upper left', ncol=1, frameon=False,
                                      framealpha=0.0, fontsize=fontsize,
                                      bbox_to_anchor=(0.0, -0.3),
                                      bbox_transform=ax_dice.transAxes),
                   display_val=False, display_format='.2%', display_pos='top',
                   tick_angle=20)
    ax_dice.set_ylim([0, 1])
    ax_dice.set_ylabel('Dice [-]', fontsize=12)
    ax_dice.set_title('Dice Coefficients', fontsize=fontsize, fontweight='bold', loc='left')

    # One row per (highest/lowest Dice) x (positive/negative) slice selection
    for k, (asc, is_ich) in enumerate(zip([False, True, False, True], [1, 1, 0, 0])):
        samp_df = slice_df[slice_df.label == is_ich].sort_values(
            by='Dice', ascending=asc).iloc[:n_samp, :]
        axs = []
        for i, (_, samp_row) in enumerate(samp_df.iterrows()):
            ax_i = fig.add_subplot(gs[k + 3, i])
            axs.append(ax_i)
            # load image (two possible dataset layouts for ICH) and window it
            if not is_brain_exp:
                try:
                    slice_im = io.imread(os.path.join(
                        data_path, f'Patient_CT/{samp_row.volID:03}/{samp_row.slice}.tif'))
                except FileNotFoundError:
                    slice_im = io.imread(os.path.join(
                        data_path, f'{samp_row.volID:03}/ct_scans/{samp_row.slice}.tif'))
            else:
                slice_im = io.imread(os.path.join(
                    data_path, f'{samp_row.volID:03}/ct/{samp_row.slice}.tif'))
            slice_im = window_ct(slice_im, win_center=cfg['data']['win_center'],
                                 win_width=cfg['data']['win_width'], out_range=(0, 1))
            # load truth mask (blank for negative slices)
            if is_ich == 1:
                if not is_brain_exp:
                    try:
                        slice_trg = io.imread(os.path.join(
                            data_path,
                            f'Patient_CT/{samp_row.volID:03}/{samp_row.slice}_ICH_Seg.bmp'))
                    except FileNotFoundError:
                        slice_trg = io.imread(os.path.join(
                            data_path, f'{samp_row.volID:03}/masks/{samp_row.slice}_ICH.bmp'))
                else:
                    slice_trg = io.imread(os.path.join(
                        data_path, f'{samp_row.volID:03}/mask/{samp_row.slice}_Seg.bmp'))
            else:
                slice_trg = np.zeros_like(slice_im)
            slice_trg = slice_trg.astype('bool')
            # load prediction and bring it to the target size (order=0)
            slice_pred = io.imread(os.path.join(
                exp_folder, f'Fold_{samp_row.Fold}/pred/{samp_row.volID}/{samp_row.slice}.bmp'))
            slice_pred = skimage.transform.resize(slice_pred, slice_trg.shape, order=0)
            slice_pred = slice_pred.astype('bool')
            imshow_pred(slice_im, slice_pred, target=slice_trg, ax=ax_i, im_cmap='gray',
                        pred_color='xkcd:vermillion', target_color='forestgreen',
                        pred_alpha=0.7, target_alpha=1, legend=False, legend_kwargs=None)
            ax_i.text(0, 1.1, f' {samp_row.volID:03} / {samp_row.slice:02}',
                      fontsize=10, fontweight='bold', color='white', ha='left',
                      va='top', transform=ax_i.transAxes)

        # nothing drawn on this row: skip the decorations (previously IndexError)
        if not axs:
            continue
        pos = axs[0].get_position()
        pos4 = axs[-1].get_position()
        # horizontal pitch between two columns; with a single sample fall back
        # to the axes width (approximation)
        pitch = axs[1].get_position().x0 - pos.x0 if len(axs) > 1 else pos.width
        # black background rectangle behind the row, including the side label
        fig.patches.extend([
            plt.Rectangle((pos.x0 - 0.5 * pos.width, pos4.y0 - 0.1 * pos.height),
                          0.5 * pos.width + pitch * len(axs), 1.3 * pos.height,
                          fc='black', ec='black', alpha=1, zorder=-1,
                          transform=fig.transFigure, figure=fig)
        ])
        axs[0].text(
            -0.25, 0.5,
            f"{'Low' if asc else 'High'}est Dice\n({'non-' if is_ich == 0 else ''}{'ICH' if not is_brain_exp else 'brain'})",
            fontsize=10, fontweight='bold', ha='center', va='center', rotation=90,
            color='lightgray', transform=axs[0].transAxes)
        handles = [matplotlib.patches.Patch(facecolor='forestgreen', alpha=1),
                   matplotlib.patches.Patch(facecolor='xkcd:vermillion', alpha=0.7)]
        labels = ['Ground Truth', 'Prediction']
        # legend under the middle sample (or the last one if the row is shorter)
        legend_ax = axs[min(n_samp // 2, len(axs) - 1)]
        legend_ax.legend(handles, labels, loc='upper center', ncol=2, frameon=False,
                         framealpha=0.0, fontsize=12, bbox_to_anchor=(0.5, 0.0),
                         bbox_transform=legend_ax.transAxes)

    # Save figure
    fig.savefig(save_fn, dpi=300, bbox_inches='tight')
def main(vol_fn, slice, pred_fn, trgt_fn, pred_color, trgt_color, win, cam_view,
         isoval, vol_alpha, overlap, save_fn):
    """
    Provide an axial, sagital, coronal and 3D view of the Nifti volume at
    vol_fn. The view are cross sections given by the integer in slice
    ([axial, sagital, coronal]). If a prediction and/or target is provided, the
    mask is/are overlaid on top on the views.
    ----------
    INPUT
        |---- vol_fn (str) path to the CT Nifti volume.
        |---- slice (str) python literal [axial, sagital, coronal] of indices.
        |---- pred_fn, trgt_fn (str or None) optional prediction / target Nifti
        |           masks; empty strings are treated as None.
        |---- pred_color, trgt_color (str) overlay colors.
        |---- win (str) python literal (center, width) CT window.
        |---- cam_view (str or None) python literal camera position for the 3D
        |           view; isometric view if not given.
        |---- isoval (float) iso-value of the CT surface in the 3D rendering.
        |---- vol_alpha (float) opacity of the CT surface in the 3D rendering.
        |---- overlap (bool) when False and both masks are given, prediction and
        |           target are shown on two separate rows.
        |---- save_fn (str or None) output path, 'A{a}_S{s}_C{c}.pdf' if not given.
    """
    slices = ast.literal_eval(slice)
    win = ast.literal_eval(win)
    cpos = ast.literal_eval(cam_view) if cam_view else None
    # normalize empty paths to None so truthiness and `is None` checks agree
    # (an empty path used to skip loading but still count as provided)
    pred_fn = pred_fn or None
    trgt_fn = trgt_fn or None

    # load volume, rotated like the 2D dataset, and windowed to [0, 1]
    vol_nii = nib.load(vol_fn)
    aspect_ratio = vol_nii.header['pixdim'][3] / vol_nii.header['pixdim'][2]
    vol = np.rot90(vol_nii.get_fdata(), k=1, axes=(0, 1))
    vol = window_ct(vol, win_center=win[0], win_width=win[1], out_range=(0, 1))

    def load_mask(fn):
        """Load a Nifti mask rotated like the CT, and its iso-surface at 1."""
        nii = nib.load(fn)
        arr = np.rot90(nii.get_fdata(), k=1, axes=(0, 1))
        grid = pv.wrap(arr)
        grid.spacing = nii.header['pixdim'][1:4]
        return arr, grid.contour([1])

    # load prediction and target (all-False volume when absent)
    if pred_fn:
        pred, surface_pred = load_mask(pred_fn)
    else:
        pred, surface_pred = np.zeros_like(vol).astype(bool), None
    if trgt_fn:
        trgt, surface_trgt = load_mask(trgt_fn)
    else:
        trgt, surface_trgt = np.zeros_like(vol).astype(bool), None

    # CT iso-surface for the 3D rendering
    data = pv.wrap(vol)
    data.spacing = vol_nii.header['pixdim'][1:4]
    surface = data.contour([isoval], )

    def render_3d(*overlays):
        """Off-screen 512x512 screenshot of the CT surface plus (mesh, color) overlays."""
        p = pv.Plotter(off_screen=True, window_size=[512, 512])
        p.background_color = 'black'
        p.add_mesh(surface, opacity=vol_alpha, clim=data.get_data_range(), color='lightgray')
        for mesh, color in overlays:
            p.add_mesh(mesh, opacity=1, color=color)
        if cpos:
            p.camera_position = cpos
        else:
            p.view_isometric()
        _, screenshot = p.show(screenshot=True)
        return screenshot

    def views(arr):
        """Axial, sagital and coronal cross sections of arr at the requested indices."""
        return (arr[:, :, slices[0]],
                np.rot90(arr[:, slices[1], :], axes=(0, 1)),
                np.rot90(arr[slices[2], :, :], axes=(0, 1)))

    vol_views = views(vol)
    pred_views = [v.astype(bool) for v in views(pred)]
    trgt_views = [v.astype(bool) for v in views(trgt)]
    aspects = ['equal', aspect_ratio, aspect_ratio]
    titles = ['Axial', 'Sagital', 'Coronal']
    common = dict(im_cmap='gray', pred_color=pred_color, pred_alpha=0.8,
                  target_color=trgt_color, target_alpha=0.8)

    if not overlap and pred_fn is not None and trgt_fn is not None:
        # separate 3D renderings for the prediction and the target
        render_pred = render_3d((surface_pred, pred_color))
        render_trgt = render_3d((surface_trgt, trgt_color))
        fig, axs = plt.subplots(2, 4, figsize=(10, 5))
        legend_kwargs = dict(loc='upper center', ncol=2, frameon=False,
                             labelcolor='white', framealpha=0.0, fontsize=10,
                             bbox_to_anchor=(0.5, -0.2),
                             bbox_transform=axs[1, 1].transAxes)
        for j in range(3):
            imshow_kwargs = dict(aspect=aspects[j], interpolation='nearest')
            # top row: prediction only
            imshow_pred(vol_views[j], pred_views[j], **common,
                        imshow_kwargs=imshow_kwargs, legend=False, ax=axs[0, j])
            axs[0, j].set_axis_off()
            axs[0, j].set_title(titles[j], color='white')
            # bottom row: target only (empty prediction), legend under sagital
            legend_args = (dict(legend=True, legend_kwargs=legend_kwargs) if j == 1
                           else dict(legend=False))
            imshow_pred(vol_views[j], np.zeros(vol_views[j].shape, dtype=bool),
                        trgt_views[j], **common, imshow_kwargs=imshow_kwargs,
                        ax=axs[1, j], **legend_args)
            axs[1, j].set_axis_off()
        # 3D rendering
        axs[0, 3].imshow(render_pred, cmap='gray')
        axs[0, 3].set_axis_off()
        axs[0, 3].set_title('3D rendering', color='white')
        axs[1, 3].imshow(render_trgt, cmap='gray')
        axs[1, 3].set_axis_off()
    else:
        # single 3D rendering with every available overlay
        overlays = []
        if pred_fn:
            overlays.append((surface_pred, pred_color))
        if trgt_fn:
            overlays.append((surface_trgt, trgt_color))
        render = render_3d(*overlays)
        fig, axs = plt.subplots(1, 4, figsize=(10, 6))
        # legend (under the sagital view) only when both masks are shown
        legend, legend_kwargs = False, None
        if pred_fn is not None or trgt_fn is not None:
            legend = trgt_fn is not None and pred_fn is not None
            legend_kwargs = dict(loc='upper center', ncol=2, frameon=False,
                                 labelcolor='white', framealpha=0.0, fontsize=10,
                                 bbox_to_anchor=(0.5, -0.1),
                                 bbox_transform=axs[1].transAxes)
        for j in range(3):
            legend_args = (dict(legend=legend, legend_kwargs=legend_kwargs) if j == 1
                           else dict(legend=False))
            imshow_pred(vol_views[j], pred_views[j], trgt_views[j], **common,
                        imshow_kwargs=dict(aspect=aspects[j], interpolation='nearest'),
                        ax=axs[j], **legend_args)
            axs[j].set_axis_off()
            axs[j].set_title(titles[j], color='white')
        # 3D rendering
        axs[3].imshow(render, cmap='gray')
        axs[3].set_axis_off()
        axs[3].set_title('3D rendering', color='white')

    # save figure
    fig.set_facecolor('black')
    fig.tight_layout()
    save_fn = save_fn if save_fn else f'A{slices[0]}_S{slices[1]}_C{slices[2]}.pdf'
    fig.savefig(save_fn, dpi=300, bbox_inches='tight')