def __getitem__(self, idx):
        """
        Load the CT slice and its segmentation mask for the sample at position idx.
        ----------
        INPUT
            |---- idx (int) the positional sample index in self.data_df.
        OUTPUT
            |---- slice (torch.tensor) the CT image with dimension (1 x H x W).
            |---- mask (torch.tensor) the segmentation mask with dimension (1 x H x W).
            |---- patient_nbr (torch.tensor) the patient id as a single value.
            |---- slice_nbr (torch.tensor) the slice number as a single value.
        """
        row = self.data_df.iloc[idx]

        # read the CT slice and optionally window it to [0, 1]
        ct_slice = io.imread(self.data_path + row.CT_fn)
        if self.window:
            ct_slice = window_ct(ct_slice,
                                 win_center=self.window[0],
                                 win_width=self.window[1],
                                 out_range=(0, 1))

        # slices without annotation are stored as the literal string 'None':
        # use an all-zero mask with the slice's shape and dtype in that case
        if row.mask_fn != 'None':
            mask = io.imread(self.data_path + row.mask_fn)
        else:
            mask = np.zeros_like(ct_slice)

        patient_nbr = torch.tensor(row.PatientNumber)
        slice_nbr = torch.tensor(row.SliceNumber)

        # data augmentation + formatting, applied jointly to image and mask
        ct_slice, mask = self.transform(ct_slice, mask)
        return ct_slice, mask, patient_nbr, slice_nbr
    def __getitem__(self, idx):
        """
        Extract the CT specified by idx and its hemorrhage label, optionally
        drawing synthetic anomalies on a healthy sample.
        ----------
        INPUT
            |---- idx (int) the sample index in self.data_df.
        OUTPUT
            |---- im (torch.tensor) the CT image with dimension (1 x H x W).
            |---- label (torch.tensor) the value of the Hemorrhage column, or 1 when
            |         synthetic anomalies were drawn on the image.
            |---- idx (torch.tensor) the sample index.
        """
        record = self.data_df.iloc[idx]
        # load dicom and recover the CT pixel values from the stored values
        # through the DICOM rescale slope / intercept
        dcm_im = pydicom.dcmread(self.data_path + record.filename)
        im = (dcm_im.pixel_array * float(dcm_im.RescaleSlope) +
              float(dcm_im.RescaleIntercept))
        # Window the CT-scan
        if self.window:
            im = window_ct(im,
                           win_center=self.window[0],
                           win_width=self.window[1],
                           out_range=(0, 1))
        # transform image (im is indexed as C x H x W below)
        im = self.transform(im)
        label = record.Hemorrhage

        # With probability self.anomaly_proba, turn a label-0 sample into a
        # positive one by drawing ellipses on it. The condition order is kept
        # so that np.random.rand() is only consumed when artificial_anomaly is set.
        if self.artificial_anomaly and (np.random.rand() <
                                        self.anomaly_proba) and (label == 0):
            anomalies = tf.ToTorchTensor()(self.draw_ellipses(
                (im.shape[1], im.shape[2]), **self.drawing_params))
            # overwrite the image wherever an ellipse was drawn
            im = torch.where(anomalies > 0, anomalies, im)
            label = 1

        return im, torch.tensor(label), torch.tensor(idx)
    def __getitem__(self, idx):
        """
        Load the CT at position idx together with a freshly drawn inpainting mask.
        ----------
        INPUT
            |---- idx (int) the sample index in self.data_df.
        OUTPUT
            |---- im (torch.tensor) the CT image with dimension (1 x H x W).
            |---- mask (torch.tensor) the inpainting mask with dimension (1 x H x W).
        """
        record = self.data_df.iloc[idx]
        dicom = pydicom.dcmread(self.data_path + record.filename)
        # convert the stored pixel values with the DICOM rescale slope / intercept
        image = (dicom.pixel_array * float(dicom.RescaleSlope) +
                 float(dicom.RescaleIntercept))

        if self.window:
            image = window_ct(image,
                              win_center=self.window[0],
                              win_width=self.window[1],
                              out_range=(0, 1))

        image = self.transform(image)
        # the mask is generated at the transformed image's spatial size (H, W)
        height, width = image.shape[1], image.shape[2]
        inpaint_mask = self.random_ff_mask((height, width))
        return image, tf.ToTorchTensor()(inpaint_mask)
    def __getitem__(self, idx):
        """
        Load the image and mask stored for the sample at position idx.
        ----------
        INPUT
            |---- idx (int) the sample index in self.data_df.
        OUTPUT
            |---- im (torch.tensor) the image with dimension (1 x H x W).
            |---- mask (torch.tensor) the mask with dimension (1 x H x W).
            |---- idx (torch.tensor) the data index in the data_df.
        """
        entry = self.data_df.iloc[idx]
        # read image and mask files (paths relative to self.data_path)
        image = io.imread(os.path.join(self.data_path, entry.im_fn))
        mask = io.imread(os.path.join(self.data_path, entry.mask_fn))

        if self.window:
            image = window_ct(image,
                              win_center=self.window[0],
                              win_width=self.window[1],
                              out_range=(0, 1))

        # joint transform so spatial augmentations stay aligned
        image, mask = self.transform(image, mask)
        return image, mask, torch.tensor(idx)
def main(input_data_path, output_data_path, window):
    """
    Convert the volumetric CT data and masks (NIfTI format) to a dataset of 2D
    images in tif and masks in bitmap for the brain extraction.
    ----------
    INPUT
        |---- input_data_path (str) folder containing 'info.csv' (with an 'id'
        |         column), 'ct_scans/<id>.nii' and 'masks/<id>.nii.gz'.
        |---- output_data_path (str) folder where slices and csv files are written.
        |---- window (tuple (center, width) or None) the CT windowing to apply,
        |         or None to keep the raw values.
    OUTPUT
        |---- None. Writes '<id:03>/ct/<k>.tif' for every slice k (1-based),
              '<id:03>/mask/<k>_Seg.bmp' for slices with a non-empty mask,
              'slice_info.csv' and 'volume_info.csv' in output_data_path.
    """
    # open data info dataframe
    info_df = pd.read_csv(os.path.join(input_data_path, 'info.csv'), index_col=0)
    # make output directory (parents included, no error if it exists)
    os.makedirs(output_data_path, exist_ok=True)
    n_volumes = len(info_df.id)
    # iterate over volumes to extract data
    output_info = []
    for n, vol_id in enumerate(info_df.id.values):
        # read nii volumes
        ct_nii = nib.load(os.path.join(input_data_path, f'ct_scans/{vol_id}.nii'))
        mask_nii = nib.load(os.path.join(input_data_path, f'masks/{vol_id}.nii.gz'))
        # get arrays and rotate 90° counter clockwise for head pointing upward
        ct_vol = np.rot90(ct_nii.get_fdata(), axes=(0, 1))
        mask_vol = np.rot90(skimage.img_as_bool(mask_nii.get_fdata()), axes=(0, 1))
        # window the ct volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol, win_center=window[0], win_width=window[1], out_range=(0, 1))

        if mask_vol.shape != ct_vol.shape:
            # NOTE(review): processing continues after this warning; the slice
            # loop indexes mask_vol with CT slice indices, so a mask with fewer
            # slices than the CT raises IndexError.
            print(f'>>> Warning! The ct volume of patient {vol_id} does not have '
                  f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})')
        # make volume directories
        os.makedirs(os.path.join(output_data_path, f'{vol_id:03}/ct/'), exist_ok=True)
        os.makedirs(os.path.join(output_data_path, f'{vol_id:03}/mask/'), exist_ok=True)
        # iterate over slices to save them (file names use 1-based slice numbers)
        n_slices = ct_vol.shape[2]
        for k in range(n_slices):
            slice_nbr = k + 1
            ct_slice = ct_vol[:, :, k]
            ct_slice_fn = f'{vol_id:03}/ct/{slice_nbr}.tif'
            # save CT slice
            skimage.io.imsave(os.path.join(output_data_path, ct_slice_fn), ct_slice, check_contrast=False)
            is_low = bool(skimage.exposure.is_low_contrast(ct_slice))
            # save mask only if some brain is present on the slice
            mask_slice = mask_vol[:, :, k]
            if np.any(mask_slice):
                mask_slice_fn = f'{vol_id:03}/mask/{slice_nbr}_Seg.bmp'
                skimage.io.imsave(os.path.join(output_data_path, mask_slice_fn),
                                  skimage.img_as_ubyte(mask_slice), check_contrast=False)
            else:
                mask_slice_fn = 'None'
            # add info to output list
            output_info.append({'volume': vol_id, 'slice': slice_nbr, 'ct_fn': ct_slice_fn,
                                'mask_fn': mask_slice_fn, 'low_contrast_ct': is_low})

            print_progessbar(k, n_slices, Name=f'Volume {vol_id:03} {n+1:03}/{n_volumes:03}',
                             Size=20, erase=False)

    # Make dataframe of outputs and save it
    output_info_df = pd.DataFrame(output_info)
    slice_info_fn = os.path.join(output_data_path, 'slice_info.csv')
    output_info_df.to_csv(slice_info_fn)
    print('>>> Slice informations saved at ' + slice_info_fn)
    # save volume df
    volume_info_fn = os.path.join(output_data_path, 'volume_info.csv')
    info_df.to_csv(volume_info_fn)
    print('>>> Volume informations saved at ' + volume_info_fn)
    def __getitem__(self, idx):
        """
        Extract the CT specified by idx, formatted according to self.mode.
        ----------
        INPUT
            |---- idx (int) the sample index in self.data_df.
        OUTPUT (depends on self.mode; images go through self.toTensor and idx
        is returned as a torch.tensor)
            |---- 'standard': (im, idx).
            |---- 'context_restoration': (im, swapped_im, idx) where swapped_im is
            |         self.swap_tranform applied to the transformed image.
            |---- 'contrastive': (im1, im2, idx), each obtained by applying
            |         self.transform then self.contrastive_transform.
            |---- 'binary_classification': (im, label, idx), label read from the
            |         Hemorrhage column.
            |---- 'multi_classification': (im, labels, idx), labels read from the
            |         columns listed in self.class_name.
        RAISES
            |---- ValueError if self.mode is not one of the modes above.
        """
        valid_modes = ('standard', 'context_restoration', 'contrastive',
                       'binary_classification', 'multi_classification')
        # fail fast (before any file I/O) instead of silently returning None,
        # which would only surface later as an obscure collate error
        if self.mode not in valid_modes:
            raise ValueError(f'Unknown mode {self.mode!r}. '
                             f'Expected one of {valid_modes}.')

        # load dicom and recover the CT pixel values through the rescale slope / intercept
        dcm_im = pydicom.dcmread(self.data_path +
                                 self.data_df.iloc[idx].filename)
        im = (dcm_im.pixel_array * float(dcm_im.RescaleSlope) +
              float(dcm_im.RescaleIntercept))
        # Window the CT-scan
        if self.window:
            im = window_ct(im,
                           win_center=self.window[0],
                           win_width=self.window[1],
                           out_range=(0, 1))

        if self.mode == 'standard':
            im = self.transform(im)
            return self.toTensor(im), torch.tensor(idx)
        elif self.mode == 'context_restoration':
            # build the corrupted version from the transformed image
            im = self.transform(im)
            swapped_im = self.swap_tranform(im)
            return self.toTensor(im), self.toTensor(swapped_im), torch.tensor(
                idx)
        elif self.mode == 'contrastive':
            # transform the image twice to get two views
            im1 = self.contrastive_transform(self.transform(im))
            im2 = self.contrastive_transform(self.transform(im))
            return self.toTensor(im1), self.toTensor(im2), torch.tensor(idx)
        elif self.mode == 'binary_classification':
            im = self.transform(im)
            label = self.data_df.iloc[idx].Hemorrhage
            return self.toTensor(im), torch.tensor(label), torch.tensor(idx)
        else:  # 'multi_classification' (mode validated above)
            im = self.transform(im)
            samp = self.data_df.iloc[idx]
            label = [samp[name] for name in self.class_name]
            return self.toTensor(im), torch.tensor(label), torch.tensor(idx)
    def __getitem__(self, idx):
        """
        Build the 2-channel network input (CT slice + attention map) for sample idx,
        along with its ground-truth mask, volume id and slice number.
        ----------
        INPUT
            |---- idx (int) the sample index in self.data_df.
        OUTPUT
            |---- input (torch.tensor) the CT image stacked with the attention map with dimension (2 x H x W).
            |---- mask (torch.tensor) the segmentation mask with dimension (1 x H x W).
            |---- patient_nbr (torch.tensor) the patient id as a single value.
            |---- slice_nbr (torch.tensor) the slice number as a single value.
        """
        row = self.data_df.iloc[idx]

        ct_slice = io.imread(self.data_path + row.ct_fn)
        if self.window:
            ct_slice = window_ct(ct_slice,
                                 win_center=self.window[0],
                                 win_width=self.window[1],
                                 out_range=(0, 1))

        # second channel: attention map brought to the slice's (H, W) with
        # bilinear interpolation, or zeros when no map is stored ('None')
        if row.attention_fn != 'None':
            raw_attention = skimage.img_as_float(
                io.imread(self.data_path + row.attention_fn))
            attention_map = skimage.transform.resize(raw_attention,
                                                     ct_slice.shape[:2],
                                                     order=1,
                                                     preserve_range=True)
        else:
            attention_map = np.zeros_like(ct_slice)
        stacked = np.stack([ct_slice, attention_map], axis=2)

        # blank mask for slices without annotation
        if row.mask_fn != 'None':
            mask = io.imread(self.data_path + row.mask_fn)
        else:
            mask = np.zeros_like(ct_slice)

        patient_nbr = torch.tensor(row.id)
        slice_nbr = torch.tensor(row.slice)

        # data augmentation + formatting, applied jointly to input and mask
        stacked, mask = self.transform(stacked, mask)
        return stacked, mask, patient_nbr, slice_nbr
    def __getitem__(self, idx):
        """
        Get the CT volume, its mask and the patient number stored in row idx.
        ----------
        INPUT
            |---- idx the row label in self.data_df (accessed with .loc).
        OUTPUT
            |---- ct_vol (torch.Tensor) the (optionally windowed) resampled CT
            |         volume, as returned by self.transform.
            |---- mask (torch.Tensor) the resampled mask as a boolean tensor.
            |---- pID (torch.Tensor) the patient number.
        """
        # load data
        ct_nii = nib.load(self.data_path + self.data_df.loc[idx, 'CT_fn'])
        mask_nii = nib.load(self.data_path + self.data_df.loc[idx, 'mask_fn'])
        pID = torch.tensor(self.data_df.loc[idx, 'PatientNumber'])
        # rot90 in the (0, 1) plane swaps the first two array axes
        ct_vol = np.rot90(ct_nii.get_fdata(), axes=(0, 1))
        mask = np.rot90(mask_nii.get_fdata(), axes=(0, 1))
        # header pixdim[1:4] gives the spacing of the original axes (0, 1, 2);
        # reorder it to follow the rotated array axes (1, 0, 2).
        # NOTE(review): assumes resample_ct expects one spacing per array axis.
        pix_dim = ct_nii.header['pixdim'][[2, 1, 3]]

        # window CT-scan for soft tissues (skipped when no window is set,
        # consistent with the other datasets)
        if self.window:
            ct_vol = window_ct(ct_vol,
                               win_center=self.window[0],
                               win_width=self.window[1],
                               out_range=(0, 1))
        # resample vol and mask
        ct_vol = resample_ct(ct_vol,
                             pix_dim,
                             out_pixel_dim=self.resampling_dim,
                             preserve_range=True,
                             order=self.resampling_order)
        # nearest-neighbour interpolation (order=0) keeps mask values discrete
        mask = resample_ct(mask,
                           pix_dim,
                           out_pixel_dim=self.resampling_dim,
                           preserve_range=True,
                           order=0)

        ct_vol, mask = self.transform(ct_vol, mask)

        return ct_vol, mask.bool(), pID
Example #9
0
    def segement_volume(self,
                        vol,
                        save_fn=None,
                        window=None,
                        input_size=(256, 256),
                        return_pred=False):
        """
        Segment each slice of the passed Nifti volume and save the results as a Nifti volume.
        ----------
        INPUT
            |---- vol (nibabel.nifti1.Nifti1Pair) the nibabel volume with metadata to segment.
            |---- save_fn (str) where to save the segmentation (not saved if None).
            |---- window (tuple (center, width)) the windowing to apply to the ct-scan.
            |---- input_size (tuple (h, w)) the input size for the network.
            |---- return_pred (bool) whether to return the volume of prediction.
        OUTPUT
            |---- (mask_vol) (nibabel.nifti1.Nifti1Pair) the prediction volume, on
                  the same grid and affine as vol (only when return_pred is True).
        """
        pred_list = []
        # 90° counterclockwise rotation in the (0, 1) plane: (X, Y, Z) -> (Y, X, Z)
        vol_data = np.rot90(vol.get_fdata(), axes=(0, 1))
        if window:
            vol_data = window_ct(vol_data,
                                 win_center=window[0],
                                 win_width=window[1],
                                 out_range=(0, 1))
        transform = tf.Compose(tf.Resize(H=input_size[0], W=input_size[1]),
                               tf.ToTorchTensor())
        self.unet.eval()
        self.unet.to(self.device)
        n_slices = vol_data.shape[2]
        with torch.no_grad():
            for s in range(0, n_slices, self.batch_size):
                # get slices in good size and as a batch tensor
                batch = transform(vol_data[:, :, s:s + self.batch_size]).to(
                    self.device).float().permute(3, 0, 1, 2)
                # predict and threshold at 0.5
                pred = self.unet(batch)
                pred = torch.where(pred >= 0.5,
                                   torch.ones_like(pred, device=self.device),
                                   torch.zeros_like(pred, device=self.device))
                # store pred as (H x W x B)
                pred_list.append(
                    pred.squeeze(dim=1).permute(1, 2, 0).cpu().numpy().astype(
                        np.uint8) * 255)
                if self.print_progress:
                    print_progessbar(s + pred.shape[0] - 1,
                                     Max=n_slices,
                                     Name='Slice',
                                     Size=20,
                                     erase=True)

        # make the prediction volume (still in the rotated (Y, X, Z) frame)
        vol_pred = np.concatenate(pred_list, axis=2)
        # Resize to the *rotated* in-plane size (dim[2], dim[1]) and only then
        # rotate 90° clockwise, so the result lands on the original (X, Y, Z)
        # grid of vol.affine even when slices are not square. order=0 with
        # anti_aliasing disabled keeps the mask binary. Without preserve_range,
        # resize rescales the 0/255 uint8 data to floats in [0, 1], so the
        # saved volume holds 0/1 labels.
        vol_pred = np.rot90(skimage.transform.resize(
            vol_pred, (vol.header['dim'][2], vol.header['dim'][1]),
            order=0,
            anti_aliasing=False),
                            axes=(1, 0))
        # make Nifti and save it
        vol_pred_nii = nib.Nifti1Pair(vol_pred.astype(np.uint8), vol.affine)
        if save_fn:
            nib.save(vol_pred_nii, save_fn)
        # return Nifti prediction
        if return_pred:
            return vol_pred_nii
Example #10
0
def main(input_data_path, output_data_path, window):
    """
    Convert the volumetric CT data and mask (in NIfTI format) to a dataset of
    2D images in tif and masks in bitmap.
    ----------
    INPUT
        |---- input_data_path (str) folder containing 'hemorrhage_diagnosis_raw_ct.csv',
        |         'Patient_demographics.csv', 'ct_scans/<id:03>.nii' and 'masks/<id:03>.nii'.
        |---- output_data_path (str) folder where 'Patient_CT/' and the csv files are written.
        |---- window (tuple (center, width) or None) the CT windowing to apply,
        |         or None to keep the raw values.
    OUTPUT
        |---- None. Writes 'Patient_CT/<id:03>/<k>.tif' for every slice k (1-based),
              'Patient_CT/<id:03>/<k>_ICH_Seg.bmp' for slices with a non-empty mask,
              'ct_info.csv' and 'patient_info.csv' in output_data_path.
    """
    # open slice-level diagnosis dataframe
    info_df = pd.read_csv(os.path.join(input_data_path, 'hemorrhage_diagnosis_raw_ct.csv'))
    # replace the No_Hemorrhage flag by its complement 'Hemorrhage'
    info_df['Hemorrhage'] = 1 - info_df.No_Hemorrhage
    info_df.drop(columns='No_Hemorrhage', inplace=True)
    # open patient info dataframe (header on the second row, 2 footer lines skipped)
    patient_df = (pd.read_csv(os.path.join(input_data_path, 'Patient_demographics.csv'),
                              header=1, skipfooter=2, engine='python')
                  .rename(columns={'Unnamed: 0': 'PatientNumber', 'Unnamed: 1': 'Age',
                                   'Unnamed: 2': 'Gender', 'Unnamed: 8': 'Fracture',
                                   'Unnamed: 9': 'Note'}))
    patient_df[patient_df.columns[3:9]] = patient_df[
        patient_df.columns[3:9]].fillna(0).astype(int)
    # add column Hemorrhage (any ICH)
    patient_df['Hemorrhage'] = patient_df[patient_df.columns[3:8]].max(axis=1)

    # make output directories (parents included, no error if they exist)
    os.makedirs(os.path.join(output_data_path, 'Patient_CT'), exist_ok=True)
    # iterate over volumes to extract data
    patient_ids = info_df.PatientNumber.unique()
    n_patients = len(patient_ids)
    output_info = []
    for n, patient_id in enumerate(patient_ids):
        # read nii volumes
        ct_nii = nib.load(os.path.join(input_data_path, f'ct_scans/{patient_id:03}.nii'))
        mask_nii = nib.load(os.path.join(input_data_path, f'masks/{patient_id:03}.nii'))
        # get arrays and rotate 90° counter clockwise for head pointing upward
        ct_vol = np.rot90(ct_nii.get_fdata(), axes=(0, 1))
        mask_vol = np.rot90(skimage.img_as_bool(mask_nii.get_fdata()), axes=(0, 1))
        # window the ct volume to get better contrast of soft tissues
        if window is not None:
            ct_vol = window_ct(ct_vol,
                               win_center=window[0],
                               win_width=window[1],
                               out_range=(0, 1))

        if mask_vol.shape != ct_vol.shape:
            # NOTE(review): processing continues after this warning; the slice
            # loop indexes mask_vol with CT slice indices, so a mask with fewer
            # slices than the CT raises IndexError.
            print(
                f'>>> Warning! The ct volume of patient {patient_id} does not have '
                f'the same dimension as the ground truth. CT ({ct_vol.shape}) vs Mask ({mask_vol.shape})'
            )
        # make patient directory
        os.makedirs(os.path.join(output_data_path, f'Patient_CT/{patient_id:03}'),
                    exist_ok=True)
        # iterate over slices to save them (file names use 1-based slice numbers)
        n_slices = ct_vol.shape[2]
        for k in range(n_slices):
            slice_nbr = k + 1
            ct_slice = ct_vol[:, :, k]
            ct_slice_fn = f'Patient_CT/{patient_id:03}/{slice_nbr}.tif'
            # save CT slice
            skimage.io.imsave(os.path.join(output_data_path, ct_slice_fn),
                              ct_slice,
                              check_contrast=False)
            is_low = bool(skimage.exposure.is_low_contrast(ct_slice))
            # save mask only if some ICH is present on the slice
            mask_slice = mask_vol[:, :, k]
            if np.any(mask_slice):
                mask_slice_fn = f'Patient_CT/{patient_id:03}/{slice_nbr}_ICH_Seg.bmp'
                skimage.io.imsave(os.path.join(output_data_path, mask_slice_fn),
                                  skimage.img_as_ubyte(mask_slice),
                                  check_contrast=False)
            else:
                mask_slice_fn = 'None'
            # add info to output list
            output_info.append({
                'PatientNumber': patient_id,
                'SliceNumber': slice_nbr,
                'CT_fn': ct_slice_fn,
                'mask_fn': mask_slice_fn,
                'low_contrast_CT': is_low
            })

            print_progessbar(
                k,
                n_slices,
                Name=f'Patient {patient_id:03} {n+1:03}/{n_patients:03}',
                Size=20,
                erase=False)

    # Make dataframe of outputs and merge it with the slice-level input info
    output_info_df = pd.DataFrame(output_info)
    info_df = pd.merge(info_df,
                       output_info_df,
                       how='inner',
                       on=['PatientNumber', 'SliceNumber'])
    # save df
    ct_info_fn = os.path.join(output_data_path, 'ct_info.csv')
    info_df.to_csv(ct_info_fn)
    print('>>> Data informations saved at ' + ct_info_fn)
    # save patient df
    patient_info_fn = os.path.join(output_data_path, 'patient_info.csv')
    patient_df.to_csv(patient_info_fn)
    print('>>> Patient informations saved at ' + patient_info_fn)
def analyse_supervised_exp(exp_folder,
                           data_path,
                           n_fold,
                           config_folder=None,
                           save_fn='results_overview.pdf',
                           is_brain_exp=False):
    """
    Generate a summary figure of the supervised ICH segmentation experiment.
    """
    if config_folder is None:
        config_folder = exp_folder
    # utility function
    def h_padcat(arr1, arr2):
        out_len = max([arr1.shape[0], arr2.shape[0]])
        arr1_pad = np.pad(arr1, ((0, out_len - arr1.shape[0]), (0, 0)),
                          constant_values=np.nan)
        arr2_pad = np.pad(arr2, ((0, out_len - arr2.shape[0]), (0, 0)),
                          constant_values=np.nan)
        return np.concatenate([arr1_pad, arr2_pad], axis=1)

    ########## get data
    # get losses
    loss_list = []
    for train_stat_fn in glob.glob(
            os.path.join(exp_folder, 'Fold_*/outputs.json')):
        with open(train_stat_fn, 'r') as fn:
            loss_list.append(np.array(json.load(fn)['train']['evolution']))

    all = np.stack(loss_list, axis=2)[:, [1, 2, 3], :]
    data_evo = [
        np.concatenate([
            np.expand_dims(np.arange(1, all.shape[0] + 1), axis=1), all[:,
                                                                        i, :]
        ],
                       axis=1) for i in range(all.shape[1])
    ]

    # load performances
    results_df = pd.read_csv(
        os.path.join(exp_folder, 'all_volume_prediction.csv'))
    results_df = results_df.drop(columns=results_df.columns[0])

    # load performances at slice level
    with open(os.path.join(config_folder, 'config.json'), 'r') as f:
        cfg = json.load(f)
    df_list = []
    for i in range(n_fold):
        df_tmp = pd.read_csv(
            os.path.join(exp_folder,
                         f'Fold_{i+1}/pred/slice_prediction_scores.csv'))
        df_tmp['Fold'] = i + 1
        df_list.append(df_tmp)
    slice_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    slice_df = slice_df.drop(columns=slice_df.columns[0])

    ########## PLOT
    fontsize = 12
    n_samp = 10
    fig = plt.figure(figsize=(15, 12))
    gs = fig.add_gridspec(
        nrows=7,
        ncols=n_samp,
        height_ratios=[0.1, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15],
        hspace=0.4)
    # Loss Evolution Plot
    ax_evo = fig.add_subplot(gs[:2, :4])
    colors_evo = ['black', 'tomato', 'dodgerblue'
                  ] if not is_brain_exp else ['black', 'tomato', 'tomato']
    serie_names_evo = [
        'Train Loss', 'Dice (all)', 'Dice (ICH)'
    ] if not is_brain_exp else ['All', 'Dice (all)', 'Dice (brain)']
    if len(loss_list) > 0:
        curve_std(data_evo,
                  serie_names_evo,
                  colors=colors_evo,
                  ax=ax_evo,
                  lw=1,
                  CI_alpha=0.05,
                  rep_alpha=0.25,
                  plot_rep=True,
                  plot_mean=True,
                  plot_CI=True,
                  legend=True,
                  legend_kwargs=dict(loc='upper left',
                                     ncol=3,
                                     frameon=False,
                                     framealpha=0.0,
                                     fontsize=fontsize,
                                     bbox_to_anchor=(0.0, -0.3),
                                     bbox_transform=ax_evo.transAxes))
    ax_evo.set_xlabel('Epoch [-]', fontsize=fontsize)
    ax_evo.set_ylabel('Dice Loss [-] ; Dice Coeff. [-]', fontsize=fontsize)
    ax_evo.set_title('Training evolution',
                     fontsize=fontsize,
                     fontweight='bold',
                     loc='left')
    ax_evo.set_xlim([1, data_evo[0].shape[0]])
    ax_evo.set_ylim([0, 1])
    ax_evo.tick_params(axis='both', labelsize=fontsize)
    ax_evo.spines['top'].set_visible(False)
    ax_evo.spines['right'].set_visible(False)

    # Conf Mat BarPlot
    ax_cm = fig.add_subplot(gs[1, 5:7])
    ax_cm_bis = fig.add_subplot(gs[0, 5:7])
    # make data
    data_cm = [
        results_df[['TP', 'TN', 'FP', 'FN']].values,
        results_df.loc[results_df.label == 1, ['TP', 'TN', 'FP', 'FN']].values,
        results_df.loc[results_df.label == 0, ['TP', 'TN', 'FP', 'FN']].values
    ]
    serie_names_cm = [
        'All', 'ICH only', 'Non-ICH only'
    ] if not is_brain_exp else ['All', 'brain only', 'No_brain only']
    group_names_cm = ['TP', 'TN', 'FP', 'FN']
    colors_cm = ['tomato', 'dodgerblue', 'cornflowerblue']
    metric_barplot(data_cm,
                   serie_names=serie_names_cm,
                   group_names=group_names_cm,
                   colors=colors_cm,
                   ax=ax_cm,
                   fontsize=fontsize,
                   jitter=True,
                   jitter_color='gray',
                   jitter_alpha=0.25,
                   legend=True,
                   legend_kwargs=dict(loc='upper left',
                                      ncol=1,
                                      frameon=False,
                                      framealpha=0.0,
                                      fontsize=fontsize,
                                      bbox_to_anchor=(0.0, -0.4),
                                      bbox_transform=ax_cm.transAxes),
                   display_val=False)
    ax_cm.set_ylim([
        0,
        (results_df[['TP', 'FP', 'FN']].values.mean(axis=0) +
         2.2 * results_df[['TP', 'FP', 'FN']].values.std(axis=0)).max()
    ])
    ax_cm.set_ylabel('Count [-]', fontsize=fontsize)
    ax_cm.yaxis.set_major_formatter(
        matplotlib.ticker.FormatStrFormatter('%.1e'))
    # duplicate plot for long tail bar of True negative
    metric_barplot(data_cm,
                   serie_names=serie_names_cm,
                   group_names=group_names_cm,
                   colors=colors_cm,
                   ax=ax_cm_bis,
                   fontsize=fontsize,
                   jitter=True,
                   jitter_color='gray',
                   jitter_alpha=0.25,
                   legend=False,
                   legend_kwargs=None,
                   display_val=False)
    ax_cm_bis.set_ylim(bottom=results_df['TN'].values.mean() -
                       2.2 * results_df['TN'].values.std())
    ax_cm_bis.set_title('Volumetric Classification',
                        fontsize=fontsize,
                        fontweight='bold',
                        loc='left')
    ax_cm_bis.spines['bottom'].set_visible(False)
    ax_cm_bis.yaxis.set_major_formatter(
        matplotlib.ticker.FormatStrFormatter('%.1e'))
    ax_cm_bis.set_xticklabels([])
    ax_cm_bis.set_xticks([])
    d = .015
    ax_cm.plot((-d, +d), (1 - d, 1 + d),
               transform=ax_cm.transAxes,
               color='k',
               clip_on=False)
    ax_cm_bis.plot((-d, +d), (-d, +d),
                   transform=ax_cm_bis.transAxes,
                   color='k',
                   clip_on=False)

    # Dice barplot
    ax_dice = fig.add_subplot(gs[0:2, 8:])
    data_dice = [
        h_padcat(
            h_padcat(results_df[['Dice']].values,
                     results_df.loc[results_df.label == 1, ['Dice']].values),
            results_df.loc[results_df.label == 0, ['Dice']].values),
        h_padcat(
            h_padcat(slice_df[['Dice']].values,
                     slice_df.loc[slice_df.label == 1, ['Dice']].values),
            slice_df.loc[slice_df.label == 0, ['Dice']].values)
    ]

    group_names_dice = [
        'All', 'ICH only', 'Non-ICH only'
    ] if not is_brain_exp else ['All', 'brain only', 'No_brain only']
    serie_names_dice = ['Volume Dice', 'Slice Dice']
    colors_dice = ['xkcd:pumpkin orange', 'xkcd:peach']  #, 'cornflowerblue']
    metric_barplot(data_dice,
                   serie_names=serie_names_dice,
                   group_names=group_names_dice,
                   colors=colors_dice,
                   ax=ax_dice,
                   fontsize=fontsize,
                   jitter=True,
                   jitter_color='gray',
                   jitter_alpha=0.25,
                   legend=True,
                   legend_kwargs=dict(loc='upper left',
                                      ncol=1,
                                      frameon=False,
                                      framealpha=0.0,
                                      fontsize=fontsize,
                                      bbox_to_anchor=(0.0, -0.3),
                                      bbox_transform=ax_dice.transAxes),
                   display_val=False,
                   display_format='.2%',
                   display_pos='top',
                   tick_angle=20)
    ax_dice.set_ylim([0, 1])
    ax_dice.set_ylabel('Dice [-]', fontsize=12)
    ax_dice.set_title('Dice Coefficients',
                      fontsize=fontsize,
                      fontweight='bold',
                      loc='left')

    # Pred sample Highest Dice with ICH
    for k, (asc,
            is_ICH) in enumerate(zip([False, True, False, True],
                                     [1, 1, 0, 0])):
        # get n_samp highest/lowest dice for ICH
        samp_df = slice_df[slice_df.label == is_ICH].sort_values(
            by='Dice', ascending=asc).iloc[:n_samp, :]
        axs = []
        for i, samp_row in enumerate(samp_df.iterrows()):
            samp_row = samp_row[1]
            ax_i = fig.add_subplot(gs[k + 3, i])
            axs.append(ax_i)
            # load image and window it
            if not is_brain_exp:
                #slice_im = io.imread(os.path.join(data_path, f'Patient_CT/{samp_row.volID:03}/{samp_row.slice}.tif'))
                try:
                    slice_im = io.imread(
                        os.path.join(
                            data_path,
                            f'Patient_CT/{samp_row.volID:03}/{samp_row.slice}.tif'
                        ))
                except FileNotFoundError:
                    slice_im = io.imread(
                        os.path.join(
                            data_path,
                            f'{samp_row.volID:03}/ct_scans/{samp_row.slice}.tif'
                        ))
            else:
                slice_im = io.imread(
                    os.path.join(
                        data_path,
                        f'{samp_row.volID:03}/ct/{samp_row.slice}.tif'))
            slice_im = window_ct(slice_im,
                                 win_center=cfg['data']['win_center'],
                                 win_width=cfg['data']['win_width'],
                                 out_range=(0, 1))
            # load truth mask
            if is_ICH == 1:
                if not is_brain_exp:
                    #slice_trg = io.imread(os.path.join(data_path, f'Patient_CT/{samp_row.volID:03}/{samp_row.slice}_ICH_Seg.bmp'))
                    try:
                        slice_trg = io.imread(
                            os.path.join(
                                data_path,
                                f'Patient_CT/{samp_row.volID:03}/{samp_row.slice}_ICH_Seg.bmp'
                            ))
                    except FileNotFoundError:
                        slice_trg = io.imread(
                            os.path.join(
                                data_path,
                                f'{samp_row.volID:03}/masks/{samp_row.slice}_ICH.bmp'
                            ))
                else:
                    slice_trg = io.imread(
                        os.path.join(
                            data_path,
                            f'{samp_row.volID:03}/mask/{samp_row.slice}_Seg.bmp'
                        ))
            else:
                slice_trg = np.zeros_like(slice_im)
            slice_trg = slice_trg.astype('bool')
            # load prediction
            slice_pred = io.imread(
                os.path.join(
                    exp_folder,
                    f'Fold_{samp_row.Fold}/pred/{samp_row.volID}/{samp_row.slice}.bmp'
                ))
            slice_pred = skimage.transform.resize(slice_pred,
                                                  slice_trg.shape,
                                                  order=0)
            slice_pred = slice_pred.astype('bool')
            # plot all
            imshow_pred(slice_im,
                        slice_pred,
                        target=slice_trg,
                        ax=ax_i,
                        im_cmap='gray',
                        pred_color='xkcd:vermillion',
                        target_color='forestgreen',
                        pred_alpha=0.7,
                        target_alpha=1,
                        legend=False,
                        legend_kwargs=None)
            ax_i.text(0,
                      1.1,
                      f' {samp_row.volID:03} / {samp_row.slice:02}',
                      fontsize=10,
                      fontweight='bold',
                      color='white',
                      ha='left',
                      va='top',
                      transform=ax_i.transAxes)

        pos = axs[0].get_position()
        pos3 = axs[1].get_position()
        pos4 = axs[-1].get_position()
        fig.patches.extend([
            plt.Rectangle(
                (pos.x0 - 0.5 * pos.width, pos4.y0 - 0.1 * pos.height),
                0.5 * pos.width + (pos3.x0 - pos.x0) * len(axs),
                1.3 * pos.height,
                fc='black',
                ec='black',
                alpha=1,
                zorder=-1,
                transform=fig.transFigure,
                figure=fig)
        ])
        axs[0].text(
            -0.25,
            0.5,
            f"{'Low' if asc else 'High'}est Dice\n({'non-' if is_ICH == 0 else ''}{'ICH' if not is_brain_exp else 'brain'})",
            fontsize=10,
            fontweight='bold',
            ha='center',
            va='center',
            rotation=90,
            color='lightgray',
            transform=axs[0].transAxes)

    handles = [
        matplotlib.patches.Patch(facecolor='forestgreen', alpha=1),
        matplotlib.patches.Patch(facecolor='xkcd:vermillion', alpha=0.7)
    ]
    labels = ['Ground Truth', 'Prediction']
    axs[n_samp // 2].legend(handles,
                            labels,
                            loc='upper center',
                            ncol=2,
                            frameon=False,
                            framealpha=0.0,
                            fontsize=12,
                            bbox_to_anchor=(0.5, 0.0),
                            bbox_transform=axs[n_samp // 2].transAxes)

    # Save figure
    fig.savefig(save_fn, dpi=300, bbox_inches='tight')
Example #12
0
def _cross_sections(arr, slice_idx):
    """
    Extract the axial, sagital and coronal cross sections of a 3D array.
    ----------
    INPUT
        |---- arr (np.array) 3D array, oriented as the volume in `main` (i.e. after np.rot90 on axes (0, 1)).
        |---- slice_idx (sequence of 3 int) the [axial, sagital, coronal] slice indices.
    OUTPUT
        |---- views (tuple of 3 np.array) the axial section arr[:, :, a], and the sagital arr[:, s, :] and
        |           coronal arr[c, :, :] sections, the latter two rotated with np.rot90 on axes (0, 1).
    """
    return (arr[:, :, slice_idx[0]],
            np.rot90(arr[:, slice_idx[1], :], axes=(0, 1)),
            np.rot90(arr[slice_idx[2], :, :], axes=(0, 1)))


def _load_mask(mask_fn):
    """
    Load a NIfTI mask with the same orientation as the volume in `main` and extract its 3D surface.
    ----------
    INPUT
        |---- mask_fn (str) path to the NIfTI mask.
    OUTPUT
        |---- mask (np.array) the mask data rotated with np.rot90 on axes (0, 1).
        |---- surface (pyvista mesh) the contour of the mask at level 1, using the mask header spacing.
    """
    mask_nii = nib.load(mask_fn)
    mask = np.rot90(mask_nii.get_fdata(), k=1, axes=(0, 1))
    data_mask = pv.wrap(mask)
    data_mask.spacing = mask_nii.header['pixdim'][1:4]
    return mask, data_mask.contour([1], )


def _render_3d(data, surface, vol_alpha, mask_surfaces, cpos):
    """
    Render off-screen the volume iso-surface with optional mask surfaces drawn opaque on top.
    ----------
    INPUT
        |---- data (pyvista dataset) the wrapped volume, whose data range sets the surface color limits.
        |---- surface (pyvista mesh) the volume iso-surface, drawn in light gray.
        |---- vol_alpha (float) the opacity of the volume iso-surface.
        |---- mask_surfaces (list of (mesh, color)) mask surfaces to add, in order.
        |---- cpos (camera position or None) the camera view; an isometric view is used when falsy.
    OUTPUT
        |---- img the screenshot returned by `Plotter.show(screenshot=True)`.
    """
    p = pv.Plotter(off_screen=True, window_size=[512, 512])
    p.background_color = 'black'
    p.add_mesh(surface,
               opacity=vol_alpha,
               clim=data.get_data_range(),
               color='lightgray')
    for mask_surface, color in mask_surfaces:
        p.add_mesh(mask_surface, opacity=1, color=color)
    if cpos:
        p.camera_position = cpos
    else:
        p.view_isometric()
    _, img = p.show(screenshot=True)
    return img


def _legend_kwargs(ax, y_anchor):
    """Return the legend keyword arguments placing a white 2-column legend below `ax` at height y_anchor."""
    return dict(loc='upper center',
                ncol=2,
                frameon=False,
                labelcolor='white',
                framealpha=0.0,
                fontsize=10,
                bbox_to_anchor=(0.5, y_anchor),
                bbox_transform=ax.transAxes)


def main(vol_fn, slice, pred_fn, trgt_fn, pred_color, trgt_color, win,
         cam_view, isoval, vol_alpha, overlap, save_fn):
    """
    Provide an axial, sagital, coronal and 3D view of the Nifti volume at vol_fn. The view are cross sections given by
    the integer in slice ([axial, sagital, coronal]). If a prediction and/or target is provided, the mask is/are overlaid
    on top on the views.
    ----------
    INPUT
        |---- vol_fn (str) path to the NIfTI volume.
        |---- slice (str) Python literal [axial, sagital, coronal] of the slice indices (parsed with ast.literal_eval).
        |---- pred_fn (str or None) path to the prediction NIfTI mask. None or '' means no prediction.
        |---- trgt_fn (str or None) path to the target NIfTI mask. None or '' means no target.
        |---- pred_color (str) color of the prediction in the 2D overlays and the 3D rendering.
        |---- trgt_color (str) color of the target in the 2D overlays and the 3D rendering.
        |---- win (str) Python literal [center, width] of the CT window applied to the volume.
        |---- cam_view (str or None) Python literal camera position for the 3D rendering (isometric view if empty).
        |---- isoval (float) iso-value of the volume surface; the volume is windowed to [0, 1] before contouring.
        |---- vol_alpha (float) opacity of the volume surface in the 3D rendering.
        |---- overlap (bool) when False and both masks are given, prediction and target are drawn on separate rows.
        |---- save_fn (str or None) output path, defaults to 'A{axial}_S{sagital}_C{coronal}.pdf'.
    OUTPUT
        |---- None. The figure is saved at save_fn.
    """
    # Unset options may arrive as empty strings: normalise them to None so that the loading step and the
    # `is not None` checks below agree (otherwise '' skips loading but selects branches needing the mask).
    pred_fn = pred_fn or None
    trgt_fn = trgt_fn or None
    slice_idx = ast.literal_eval(slice)
    win = ast.literal_eval(win)
    cpos = ast.literal_eval(cam_view) if cam_view else None

    # load volume, rotate it in the (0, 1) plane and window it to [0, 1]
    vol_nii = nib.load(vol_fn)
    aspect_ratio = vol_nii.header['pixdim'][3] / vol_nii.header['pixdim'][2]
    vol = np.rot90(vol_nii.get_fdata(), k=1, axes=(0, 1))
    vol = window_ct(vol, win_center=win[0], win_width=win[1], out_range=(0, 1))
    # iso-surface of the windowed volume for the 3D rendering
    data = pv.wrap(vol)
    data.spacing = vol_nii.header['pixdim'][1:4]
    surface = data.contour([isoval], )

    # load prediction and target (empty masks when not provided)
    surface_pred, surface_trgt = None, None
    if pred_fn is not None:
        pred, surface_pred = _load_mask(pred_fn)
    else:
        pred = np.zeros_like(vol).astype(bool)
    if trgt_fn is not None:
        trgt, surface_trgt = _load_mask(trgt_fn)
    else:
        trgt = np.zeros_like(vol).astype(bool)

    # 2D cross sections in the order [axial, sagital, coronal]
    vol_views = _cross_sections(vol, slice_idx)
    pred_views = [view.astype(bool) for view in _cross_sections(pred, slice_idx)]
    trgt_views = [view.astype(bool) for view in _cross_sections(trgt, slice_idx)]
    titles = ('Axial', 'Sagital', 'Coronal')
    # axial view with square pixels, the others stretched by pixdim[3] / pixdim[2]
    aspects = ('equal', aspect_ratio, aspect_ratio)
    style_kwargs = dict(im_cmap='gray',
                        pred_color=pred_color,
                        pred_alpha=0.8,
                        target_color=trgt_color,
                        target_alpha=0.8)

    if not overlap and pred_fn is not None and trgt_fn is not None:
        # separate 3D renderings for the prediction and the target
        render_pred = _render_3d(data, surface, vol_alpha,
                                 [(surface_pred, pred_color)], cpos)
        render_trgt = _render_3d(data, surface, vol_alpha,
                                 [(surface_trgt, trgt_color)], cpos)

        # prediction on the top row, target on the bottom row
        fig, axs = plt.subplots(2, 4, figsize=(10, 5))
        for j, (title, aspect) in enumerate(zip(titles, aspects)):
            imshow_pred(vol_views[j],
                        pred_views[j],
                        imshow_kwargs=dict(aspect=aspect,
                                           interpolation='nearest'),
                        legend=False,
                        ax=axs[0, j],
                        **style_kwargs)
            axs[0, j].set_axis_off()
            axs[0, j].set_title(title, color='white')
            # target only (empty prediction); the legend goes under the sagital view
            show_legend = j == 1
            imshow_pred(vol_views[j],
                        np.zeros_like(vol_views[j], dtype=bool),
                        target=trgt_views[j],
                        imshow_kwargs=dict(aspect=aspect,
                                           interpolation='nearest'),
                        legend=show_legend,
                        legend_kwargs=_legend_kwargs(axs[1, j], -0.2)
                        if show_legend else None,
                        ax=axs[1, j],
                        **style_kwargs)
            axs[1, j].set_axis_off()
        axs[0, 3].imshow(render_pred, cmap='gray')
        axs[0, 3].set_axis_off()
        axs[0, 3].set_title('3D rendering', color='white')
        axs[1, 3].imshow(render_trgt, cmap='gray')
        axs[1, 3].set_axis_off()
    else:
        # single 3D rendering with every provided mask (prediction first, then target)
        mask_surfaces = []
        if surface_pred is not None:
            mask_surfaces.append((surface_pred, pred_color))
        if surface_trgt is not None:
            mask_surfaces.append((surface_trgt, trgt_color))
        render = _render_3d(data, surface, vol_alpha, mask_surfaces, cpos)

        # prediction and target overlaid on the same views; a legend only when both are shown
        fig, axs = plt.subplots(1, 4, figsize=(10, 6))
        both_masks = pred_fn is not None and trgt_fn is not None
        for j, (title, aspect) in enumerate(zip(titles, aspects)):
            show_legend = j == 1 and both_masks
            imshow_pred(vol_views[j],
                        pred_views[j],
                        target=trgt_views[j],
                        imshow_kwargs=dict(aspect=aspect,
                                           interpolation='nearest'),
                        legend=show_legend,
                        legend_kwargs=_legend_kwargs(axs[j], -0.1)
                        if show_legend else None,
                        ax=axs[j],
                        **style_kwargs)
            axs[j].set_axis_off()
            axs[j].set_title(title, color='white')
        axs[3].imshow(render, cmap='gray')
        axs[3].set_axis_off()
        axs[3].set_title('3D rendering', color='white')

    # save figure and release it
    fig.set_facecolor('black')
    fig.tight_layout()
    save_fn = save_fn if save_fn else f'A{slice_idx[0]}_S{slice_idx[1]}_C{slice_idx[2]}.pdf'
    fig.savefig(save_fn, dpi=300, bbox_inches='tight')
    plt.close(fig)