def denoising_diagnosis(raw_img,init_denoise,warped_mask,final_denoise, name_source, out_dir):
    """Save a four-column QC figure comparing the denoising stages.

    The figure has 3 rows x 4 columns; each column is drawn by ``plot_3d``
    over the three axes of that column:

    * column 0 -- ``raw_img`` after ``otsu_scaling`` ('Raw EPI')
    * column 1 -- ``init_denoise`` after ``otsu_scaling`` ('Initial Denoising')
    * column 2 -- the same scaled ``init_denoise`` with ``warped_mask``
      overlaid semi-transparently ('Resampled Mask')
    * column 3 -- ``final_denoise`` after ``otsu_scaling`` ('Final Denoising')

    Args:
        raw_img: Input forwarded as-is to ``otsu_scaling``.
        init_denoise: Input forwarded as-is to ``otsu_scaling``.
        warped_mask: File readable by ``sitk.ReadImage``; loaded as float32.
        final_denoise: Input forwarded as-is to ``otsu_scaling``.
        name_source: Path whose basename, truncated at the first ``.nii``,
            is used as the stem of the output file name.
        out_dir: Output directory; created if it does not exist.

    Returns:
        None. The figure is written to ``<out_dir>/<stem>_denoising.png``.
    """
    # Imports are kept local to the function, as in the original code.
    import os
    import pathlib
    import SimpleITK as sitk
    import matplotlib.pyplot as plt
    from rabies.preprocess_pkg.preprocess_visual_QC import plot_3d,otsu_scaling

    filename_template = pathlib.Path(name_source).name.rsplit(".nii")[0]
    os.makedirs(out_dir, exist_ok=True)
    prefix = os.path.join(out_dir, filename_template)

    fig,axes = plt.subplots(nrows=3, ncols=4, figsize=(12*4,2*3))
    try:
        plt.tight_layout()

        def show_scaled(column, image):
            # All intensity panels share the same 0-1 range and colormap.
            plot_3d(axes[:,column],image,fig=fig,vmin=0,vmax=1,cmap='viridis')

        axes[0,0].set_title('Raw EPI', fontsize=20)
        show_scaled(0, otsu_scaling(raw_img))

        # The scaled initial denoising is used for columns 1 and 2.
        scaled_init = otsu_scaling(init_denoise)
        axes[0,1].set_title('Initial Denoising', fontsize=20)
        show_scaled(1, scaled_init)

        axes[0,2].set_title('Resampled Mask', fontsize=20)
        show_scaled(2, scaled_init)
        plot_3d(axes[:,2],sitk.ReadImage(warped_mask,sitk.sitkFloat32),fig=fig,vmin=-1,vmax=1,cmap='bwr', alpha=0.3, cbar=False)

        axes[0,3].set_title('Final Denoising', fontsize=20)
        show_scaled(3, otsu_scaling(final_denoise))

        fig.savefig('%s_denoising.png' % (prefix), bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
def plot_reg(image1,image2, name_source, out_dir):
    """Save a 2x3 figure showing two images with each other's edges overlaid.

    Row 0 shows ``image1`` (after ``otsu_scaling``) with the edges of
    ``image2`` added; row 1 shows ``image2`` with the edges of ``image1``.

    Args:
        image1: Input forwarded to ``otsu_scaling`` and to ``add_edges``.
        image2: Input forwarded to ``otsu_scaling`` and to ``add_edges``.
        name_source: Path whose basename, truncated at the first ``.nii``,
            is used as the stem of the output file name.
        out_dir: Output directory; created if it does not exist.

    Returns:
        None. The figure is written to ``<out_dir>/<stem>_registration.png``.
    """
    import os
    import pathlib
    import matplotlib.pyplot as plt
    from rabies.preprocess_pkg.preprocess_visual_QC import plot_3d,otsu_scaling

    filename_template = pathlib.Path(name_source).name.rsplit(".nii")[0]
    os.makedirs(out_dir, exist_ok=True)
    prefix = os.path.join(out_dir, filename_template)

    fig,axes = plt.subplots(nrows=2, ncols=3, figsize=(12*3,2*2))
    try:
        plt.tight_layout()

        # NOTE(review): plot_3d is called here as (image, axes, ...) and is
        # expected to return three display objects exposing ``add_edges``.
        # Confirm this still matches the current plot_3d signature.
        pairings = ((image1, image2), (image2, image1))
        for row, (background, edge_source) in enumerate(pairings):
            scaled = otsu_scaling(background)
            display1,display2,display3 = plot_3d(scaled,axes[row,:], cmap='gray')
            for display in (display1, display2, display3):
                display.add_edges(edge_source)

        fig.savefig('%s_registration.png' % (prefix), bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def temporal_features(bold_file, confounds_csv, FD_csv, rabies_data_type, name_source, out_dir):
    """Plot motion traces and temporal STD / tSNR maps, and write the maps.

    The 3x3 figure contains, in column 0, the ``mov1-3`` and ``rot1-3``
    columns of ``confounds_csv`` and the ``Mean`` column of ``FD_csv``;
    column 1 shows the temporal standard deviation map and column 2 the
    temporal SNR map (mean / std), both drawn with ``plot_3d``.

    Args:
        bold_file: Image file read with ``sitk.ReadImage``.
        confounds_csv: CSV with columns ``mov1``, ``mov2``, ``mov3``,
            ``rot1``, ``rot2``, ``rot3``.
        FD_csv: CSV with a ``Mean`` column.
        rabies_data_type: Pixel type passed to ``sitk.ReadImage``.
        name_source: Path whose basename, truncated at the first ``.nii``,
            is used as the stem of the figure file name.
        out_dir: Directory for the figure; created if it does not exist.

    Returns:
        tuple: ``(std_filename, tSNR_filename)`` -- absolute paths of
        ``tSTD.nii.gz`` and ``tSNR.nii.gz``, both written to the current
        working directory. The figure itself is saved as
        ``<out_dir>/<stem>_temporal_features.png``.
    """
    import os
    import pathlib
    import numpy as np
    import pandas as pd
    import SimpleITK as sitk
    import matplotlib.pyplot as plt
    from rabies.preprocess_pkg.preprocess_visual_QC import plot_3d
    from rabies.preprocess_pkg.utils import copyInfo_3DImage

    filename_template = pathlib.Path(name_source).name.rsplit(".nii")[0]
    os.makedirs(out_dir, exist_ok=True)
    prefix = os.path.join(out_dir, filename_template)

    fig,axes = plt.subplots(nrows=3, ncols=3, figsize=(20,5))
    try:
        # plot the motion timecourses
        df = pd.read_csv(confounds_csv)
        motion_panels = (
            (axes[0,0], ['mov1','mov2','mov3'], 'Translation parameters'),
            (axes[1,0], ['rot1','rot2','rot3'], 'Rotation parameters'),
        )
        for ax, columns, title in motion_panels:
            for column in columns:
                ax.plot(df[column])
            ax.legend(columns)
            ax.set_title(title, fontsize=20)

        df = pd.read_csv(FD_csv)
        ax = axes[2,0]
        ax.plot(df['Mean'], color='r')
        ax.set_title('Framewise Displacement', fontsize=20)

        plt.tight_layout()

        # calculate STD and tSNR map on preprocessed timeseries
        img = sitk.ReadImage(bold_file, rabies_data_type)
        array = sitk.GetArrayFromImage(img)
        # Statistics are taken along the first array axis
        # (assumed to be time for a 4D series -- TODO confirm).
        mean = array.mean(axis=0)
        std = array.std(axis=0)
        std_filename = os.path.abspath('tSTD.nii.gz')
        std_image = copyInfo_3DImage(
            sitk.GetImageFromArray(std, isVector=False), img)
        sitk.WriteImage(std_image, std_filename)

        # Where std == 0 the ratio is NaN (0/0) or +/-inf (x/0). Zero out
        # every non-finite value: an inf would otherwise end up in the
        # written map and as the colour-scale maximum below.
        with np.errstate(divide='ignore', invalid='ignore'):
            tSNR = np.divide(mean, std)
        tSNR[~np.isfinite(tSNR)] = 0
        tSNR_filename = os.path.abspath('tSNR.nii.gz')
        tSNR_image = copyInfo_3DImage(
            sitk.GetImageFromArray(tSNR, isVector=False), img)
        sitk.WriteImage(tSNR_image, tSNR_filename)

        axes[0,1].set_title('Temporal STD', fontsize=20)
        # Colour scale is capped at the value found 95% of the way through
        # the sorted STD values, so a few extreme voxels do not dominate.
        sorted_std = np.sort(std, axis=None)
        std_vmax = sorted_std[int(len(sorted_std)*0.95)]
        plot_3d(axes[:,1],std_image,fig=fig,vmin=0,vmax=std_vmax,cmap='inferno', cbar=True)
        axes[0,2].set_title('Temporal SNR', fontsize=20)
        plot_3d(axes[:,2],tSNR_image,fig=fig,vmin=0,vmax=tSNR.max(),cmap='Spectral', cbar=True)

        fig.savefig('%s_temporal_features.png' % (prefix), bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    return std_filename, tSNR_filename
def template_info(anat_template, opts, out_dir):
    """Save an overview figure of the anatomical template and its masks.

    The 3x6 figure shows, column by column: the template alone, then the
    template with the brain, WM, CSF and vascular masks overlaid, and
    finally the template with the atlas labels overlaid.

    Args:
        anat_template: Input forwarded to ``otsu_scaling``.
        opts: Object exposing ``brain_mask``, ``WM_mask``, ``CSF_mask``,
            ``vascular_mask`` and ``labels``; each is converted with
            ``str()`` and read with ``sitk.ReadImage``.
        out_dir: Output directory; created if it does not exist.

    Returns:
        None. The figure is written to ``<out_dir>/template_files.png``.

    Raises:
        ValueError: If the brain, WM, CSF or vascular mask contains any
            value other than 0 or 1.
    """
    import os
    # NOTE(review): ``plotting`` is never referenced in this function; the
    # import is kept in case an import-time side effect is relied upon.
    from nilearn import plotting
    import matplotlib.pyplot as plt
    import SimpleITK as sitk
    from rabies.preprocess_pkg.preprocess_visual_QC import plot_3d,otsu_scaling

    brain_mask = str(opts.brain_mask)
    WM_mask = str(opts.WM_mask)
    CSF_mask = str(opts.CSF_mask)
    vascular_mask = str(opts.vascular_mask)
    labels = str(opts.labels)
    os.makedirs(out_dir, exist_ok=True)

    # make sure that masks are binary (the CSF mask is overlaid like the
    # others below, so it is validated as well)
    for mask in [brain_mask,WM_mask,CSF_mask,vascular_mask]:
        array = sitk.GetArrayFromImage(sitk.ReadImage(mask))
        if ((array!=0)&(array!=1)).any():
            raise ValueError("The file %s is not a binary mask. Non-binary masks cannot be processed." % (mask))

    scaled = otsu_scaling(anat_template)

    fig,axes = plt.subplots(nrows=3, ncols=6, figsize=(4*6,2*2))
    try:
        axes[0,0].set_title('Anatomical Template', fontsize=20)
        plot_3d(axes[:,0],scaled,fig=fig,vmin=0,vmax=1,cmap='gray')

        # (title, mask file, overlay opacity) for columns 1-4
        mask_panels = (
            ('Brain Mask', brain_mask, 0.3),
            ('WM Mask', WM_mask, 0.5),
            ('CSF Mask', CSF_mask, 0.5),
            ('Vascular Mask', vascular_mask, 0.5),
        )
        for column, (title, mask, alpha) in enumerate(mask_panels, start=1):
            sitk_mask = sitk.ReadImage(mask, sitk.sitkFloat32)
            axes[0,column].set_title(title, fontsize=20)
            plot_3d(axes[:,column],scaled,fig=fig,vmin=0,vmax=1,cmap='gray')
            plot_3d(axes[:,column],sitk_mask,fig=fig,vmin=-1,vmax=1,cmap='bwr', alpha=alpha, cbar=False)

        # plot labels; the colour range runs from 1 to the largest label value
        sitk_labels = sitk.ReadImage(labels, sitk.sitkFloat32)
        axes[0,5].set_title('Atlas Labels', fontsize=20)
        plot_3d(axes[:,5],scaled,fig=fig,vmin=0,vmax=1,cmap='gray')
        plot_3d(axes[:,5],sitk_labels,fig=fig,vmin=1,vmax=sitk.GetArrayFromImage(sitk_labels).max(),cmap='rainbow', alpha=0.5, cbar=False)
        plt.tight_layout()

        fig.savefig(os.path.join(out_dir, 'template_files.png'), bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
Exemple #5
0
def scan_diagnosis(bold_file,
                   mask_file_dict,
                   temporal_info,
                   spatial_info,
                   confounds_csv,
                   regional_grayplot=False):
    """Build the temporal and spatial diagnosis figures for one scan.

    The first figure stacks a grayplot of ``bold_file`` above the
    translation and rotation traces read from ``confounds_csv``, the
    framewise displacement trace, the DVARS / mask / 'CR R^2' traces and
    the signal / noise component traces taken from ``temporal_info``.

    The second figure has one row per spatial map, each drawn over the
    scaled template: 'Temporal STD', 'CR R^2', 'Global Signal Correlation',
    'DVARS Correlation', 'FD Correlation', followed by one row per entry
    of ``spatial_info['DR_BOLD']``.

    Args:
        bold_file: Forwarded to ``grayplot`` / ``grayplot_regional``.
        mask_file_dict: Mapping providing ``'template_file'`` and
            ``'brain_mask'``; also forwarded to the grayplot helpers.
        temporal_info: Mapping providing ``'FD_trace'``, ``'DVARS'``,
            ``'edge_trace'``, ``'WM_trace'``, ``'CSF_trace'``,
            ``'VE_temporal'``, ``'signal_trace'`` and ``'noise_trace'``.
        spatial_info: Mapping providing ``'DR_BOLD'`` (array with a
            ``shape``), ``'temporal_std'``, ``'VE_spatial'``, ``'GS_corr'``,
            ``'DVARS_corr'`` and ``'FD_corr'``.
        confounds_csv: CSV with columns ``mov1-3`` and ``rot1-3``.
        regional_grayplot: When true, use ``grayplot_regional`` and add two
            narrow side strips showing the values it returns.

    Returns:
        tuple: ``(fig, fig2)`` -- the temporal figure and the spatial
        figure. Neither is saved or closed here.

    Side effects:
        * ``temporal_info['DVARS'][0]`` is overwritten with ``None`` in
          place (see the note in the body).
        * ``temp_img.nii.gz`` is (re)written in the current working
          directory for every spatial map and is not removed.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import pandas as pd
    from rabies.preprocess_pkg.preprocess_visual_QC import plot_3d, otsu_scaling

    def hide_spines(ax, sides=('right', 'top')):
        # Remove the box lines on the requested sides of an axis.
        for side in sides:
            ax.spines[side].set_visible(False)

    template_file = mask_file_dict['template_file']

    # ---- Figure 1: grayplot with stacked temporal traces -----------------
    fig, axCenter = plt.subplots(figsize=(6, 18))
    fig.subplots_adjust(.2, .1, .8, .95)

    # The order of append_axes calls fixes the top-to-bottom layout.
    divider = make_axes_locatable(axCenter)
    ax1 = divider.append_axes('bottom', size='25%', pad=0.5)
    ax1_ = divider.append_axes('bottom', size='25%', pad=0.1)
    ax2 = divider.append_axes('bottom', size='50%', pad=0.5)
    ax3 = divider.append_axes('bottom', size='50%', pad=0.5)
    ax4 = divider.append_axes('bottom', size='50%', pad=0.5)

    if regional_grayplot:
        _, slice_alt, region_mask_label = grayplot_regional(
            bold_file, mask_file_dict, fig, axCenter)
        axCenter.yaxis.labelpad = 40
        ax_slice = divider.append_axes('left', size='5%', pad=0.0)
        ax_label = divider.append_axes('left', size='5%', pad=0.0)

        # Each strip is rendered as a single-column image.
        ax_slice.imshow(slice_alt.reshape(-1, 1),
                        cmap='gray',
                        vmin=0,
                        vmax=1.1,
                        aspect='auto')
        ax_label.imshow(region_mask_label.reshape(-1, 1),
                        cmap='Spectral',
                        aspect='auto')
        ax_slice.axis('off')
        ax_label.axis('off')
    else:
        grayplot(bold_file, mask_file_dict, fig, axCenter)

    axCenter.set_ylabel('Voxels', fontsize=20)
    hide_spines(axCenter, sides=('right', 'top', 'bottom', 'left'))
    axCenter.axes.get_yaxis().set_ticks([])
    plt.setp(axCenter.get_xticklabels(), visible=False)

    # plot the motion timecourses
    df = pd.read_csv(confounds_csv)
    motion_panels = (
        (ax1, ('mov1', 'mov2', 'mov3'),
         ['translation 1', 'translation 2', 'translation 3']),
        (ax1_, ('rot1', 'rot2', 'rot3'),
         ['rotation 1', 'rotation 2', 'rotation 3']),
    )
    for ax, columns, legend_labels in motion_panels:
        for column in columns:
            ax.plot(df[column])
        ax.legend(legend_labels,
                  loc='center left',
                  bbox_to_anchor=(1, 0.5))
        hide_spines(ax)
        plt.setp(ax.get_xticklabels(), visible=False)

    y = temporal_info['FD_trace']
    ax2.plot(y, 'r')
    ax2.set_xlim([0, len(y)])
    ax2.legend(['Framewise Displacement (FD)'], loc='upper right')
    hide_spines(ax2)
    ax2.set_ylim([0.0, 0.1])
    plt.setp(ax2.get_xticklabels(), visible=False)

    # NOTE(review): this writes into the caller's temporal_info['DVARS']
    # (no copy is made). Kept as-is because callers may rely on the first
    # entry being blanked -- confirm before changing.
    DVARS = temporal_info['DVARS']
    DVARS[0] = None
    ax3.plot(DVARS)
    ax3.set_xlim([0, len(DVARS)])
    for trace_key in ('edge_trace', 'WM_trace', 'CSF_trace', 'VE_temporal'):
        ax3.plot(temporal_info[trace_key])
    ax3.legend(['DVARS', 'Edge Mask', 'WM Mask', 'CSF Mask', 'CR R^2'],
               loc='center left',
               bbox_to_anchor=(1, 0.5))
    hide_spines(ax3)
    ax3.set_ylim([0.0, 1.5])
    plt.setp(ax3.get_xticklabels(), visible=False)

    y = temporal_info['signal_trace']
    ax4.plot(y)
    ax4.set_xlim([0, len(y)])
    ax4.plot(temporal_info['noise_trace'])
    ax4.legend(['BOLD components', 'Confound components'], loc='upper right')
    hide_spines(ax4)
    ax4.set_ylim([0.0, 4.0])
    ax4.set_xlabel('Timepoint', fontsize=15)

    # ---- Figure 2: spatial maps overlaid on the template -----------------
    dr_maps = spatial_info['DR_BOLD']
    mask_file = mask_file_dict['brain_mask']

    # Overlay settings shared by every correlation-type map.
    corr_style = dict(vmin=-1, vmax=1, cmap='cold_hot', threshold=0.1)
    # (spatial_info key, row title, overlay settings) for the fixed rows.
    # 'Temporal STD' deliberately passes no threshold, so plot_3d's own
    # default applies for that row.
    fixed_rows = (
        ('temporal_std', 'Temporal STD',
         dict(vmin=0, vmax=1, cmap='inferno')),
        ('VE_spatial', 'CR R^2',
         dict(vmin=0, vmax=1, cmap='inferno', threshold=0.1)),
        ('GS_corr', 'Global Signal Correlation', corr_style),
        ('DVARS_corr', 'DVARS Correlation', corr_style),
        ('FD_corr', 'FD Correlation', corr_style),
    )

    nrows = len(fixed_rows) + dr_maps.shape[0]

    fig2, axes2 = plt.subplots(nrows=nrows,
                               ncols=3,
                               figsize=(12 * 3, 2 * nrows))
    plt.tight_layout()

    scaled = otsu_scaling(template_file)

    def overlay_row(axes, values, title, **overlay_kwargs):
        # Draw the template as background, then the map on top of it.
        plot_3d(axes,
                scaled,
                fig2,
                vmin=0,
                vmax=1,
                cmap='gray',
                alpha=1,
                cbar=False,
                num_slices=6)
        # The map is written to a scratch file by recover_3D's result and
        # read back with SimpleITK before plotting.
        analysis_functions.recover_3D(
            mask_file, values).to_filename('temp_img.nii.gz')
        sitk_img = sitk.ReadImage('temp_img.nii.gz')
        plot_3d(axes,
                sitk_img,
                fig2,
                alpha=1,
                cbar=True,
                num_slices=6,
                **overlay_kwargs)
        for ax in axes:
            ax.set_title(title, fontsize=25)

    for row, (key, title, style) in enumerate(fixed_rows):
        overlay_row(axes2[row, :], spatial_info[key], title, **style)

    for i in range(dr_maps.shape[0]):
        overlay_row(axes2[i + len(fixed_rows), :],
                    dr_maps[i, :],
                    'BOLD component %s' % (i),
                    **corr_style)

    return fig, fig2
Exemple #6
0
    def _run_interface(self, runtime):
        """Generate the dataset-level cross-correlation diagnosis figure.

        The entries of ``self.inputs.spatial_info_list`` are flattened with
        ``flatten_list``. With fewer than three entries, a warning is
        logged and an empty figure is saved as
        ``empty_dataset_diagnosis.png``. Otherwise, the per-entry maps are
        stacked and, for every pair (row map ``i``, column map ``j``) with
        ``i > j`` among the first five columns, the output of
        ``elementwise_corrcoef`` is drawn over the scaled template; the
        figure is saved as ``dataset_diagnosis.png``.

        Both files are written to the current working directory, and the
        absolute path is stored in ``self.figure_dataset_diagnosis``.
        ``temp_img.nii.gz`` is also (re)written in the working directory
        for each plotted pair and is not removed.

        Args:
            runtime: Runtime object; returned unchanged.

        Returns:
            The ``runtime`` argument.
        """
        from rabies.preprocess_pkg.utils import flatten_list
        merged = flatten_list(list(self.inputs.spatial_info_list))
        if len(merged) < 3:
            import logging
            log = logging.getLogger('root')
            log.warning(
                "Cannot run statistics on a sample size smaller than 3, so an empty figure is generated."
            )
            empty_path = os.path.abspath('empty_dataset_diagnosis.png')
            fig, _ = plt.subplots()
            try:
                fig.savefig(empty_path, bbox_inches='tight')
            finally:
                # Release the figure so it does not stay registered in pyplot.
                plt.close(fig)

            setattr(self, 'figure_dataset_diagnosis', empty_path)
            return runtime

        # Keys holding a single map per entry; they also serve as labels.
        single_map_keys = [
            'temporal_std', 'VE_spatial', 'GS_corr', 'DVARS_corr', 'FD_corr'
        ]

        voxelwise_list = []
        for spatial_info in merged:
            single_maps = [spatial_info[key] for key in single_map_keys]
            dr_maps = spatial_info['DR_BOLD']
            prior_maps = spatial_info['prior_modeling_maps']

            # Stack along the map axis; prior maps are appended only when
            # present, since an empty list cannot be concatenated.
            stacks = [np.array(single_maps), np.array(dr_maps)]
            if len(prior_maps) > 0:
                stacks.append(np.array(prior_maps))
            voxelwise_list.append(np.concatenate(stacks, axis=0))

            # The counts of the last entry are used for the labels below.
            num_DR_maps = len(dr_maps)
            num_prior_maps = len(prior_maps)
        voxelwise_array = np.array(voxelwise_list)

        label_name = list(single_map_keys)
        label_name += [
            'BOLD Dual Regression map %s' % (i) for i in range(num_DR_maps)
        ]
        label_name += [
            'BOLD Dual Convergence map %s' % (i) for i in range(num_prior_maps)
        ]

        template_file = self.inputs.mask_file_dict['template_file']
        mask_file = self.inputs.mask_file_dict['brain_mask']
        from rabies.preprocess_pkg.preprocess_visual_QC import plot_3d, otsu_scaling
        scaled = otsu_scaling(template_file)

        ncols = 5
        num_maps = voxelwise_array.shape[1]
        output_path = os.path.abspath('dataset_diagnosis.png')
        fig, axes = plt.subplots(nrows=num_maps,
                                 ncols=ncols,
                                 figsize=(12 * ncols, 2 * num_maps))
        try:
            for i, x_label in zip(range(num_maps), label_name):
                for j, y_label in enumerate(label_name[:ncols]):
                    ax = axes[i, j]
                    # Only the strictly lower triangle is plotted.
                    if i <= j:
                        ax.axis('off')
                        continue

                    X = voxelwise_array[:, i, :]
                    Y = voxelwise_array[:, j, :]
                    corr = elementwise_corrcoef(X, Y)

                    # ('coronal',) is a one-element tuple; without the
                    # trailing comma the parentheses yield a plain string.
                    plot_3d([ax],
                            scaled,
                            fig,
                            vmin=0,
                            vmax=1,
                            cmap='gray',
                            alpha=1,
                            cbar=False,
                            num_slices=6,
                            planes=('coronal',))
                    analysis_functions.recover_3D(
                        mask_file, corr).to_filename('temp_img.nii.gz')
                    sitk_img = sitk.ReadImage('temp_img.nii.gz')
                    plot_3d([ax],
                            sitk_img,
                            fig,
                            vmin=-0.7,
                            vmax=0.7,
                            cmap='cold_hot',
                            alpha=1,
                            cbar=True,
                            threshold=0.1,
                            num_slices=6,
                            planes=('coronal',))
                    ax.set_title('Cross-correlation for %s and %s' %
                                 (x_label, y_label),
                                 fontsize=15)
            fig.savefig(output_path, bbox_inches='tight')
        finally:
            # Release the figure so it does not stay registered in pyplot.
            plt.close(fig)

        setattr(self, 'figure_dataset_diagnosis', output_path)
        return runtime