Ejemplo n.º 1
0
def plot_reg(image1, image2, name_source, out_dir):
    """Save a two-row registration QC figure.

    The top row shows ``image1`` (Otsu-scaled, grayscale) with the edges of
    ``image2`` drawn on top; the bottom row shows ``image2`` with the edges of
    ``image1``.

    Args:
        image1: first image; passed to ``otsu_scaling`` and ``add_edges``.
        image2: second image; passed to ``otsu_scaling`` and ``add_edges``.
        name_source: path whose file name, cut at the first ".nii", prefixes
            the output file.
        out_dir: output directory; created if it does not exist.

    The figure is written to ``<out_dir>/<name>_registration.png``.
    """
    import os
    import pathlib
    import matplotlib.pyplot as plt
    from rabies.visualization import plot_3d, otsu_scaling

    filename_template = pathlib.Path(name_source).name.rsplit(".nii")[0]
    os.makedirs(out_dir, exist_ok=True)
    prefix = out_dir + '/' + filename_template

    fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12 * 3, 2 * 2))
    plt.tight_layout()

    # Row 0: image1 with image2's edges; row 1: image2 with image1's edges.
    pairs = ((image1, image2), (image2, image1))
    for row, (background, edge_source) in enumerate(pairs):
        scaled = otsu_scaling(background)
        displays = plot_3d(scaled, axes[row, :], cmap='gray')
        for display in displays:
            display.add_edges(edge_source)

    fig.savefig(f'{prefix}_registration.png', bbox_inches='tight')
    # pyplot keeps figures created with plt.subplots alive until closed;
    # release it so repeated calls do not accumulate open figures.
    plt.close(fig)
Ejemplo n.º 2
0
def template_masking(template, mask, out_dir):
    """Save a QC figure of a brain mask overlaid on a template.

    Args:
        template: template image; Otsu-scaled and shown in grayscale.
        mask: path to the mask image; read as float32 and resampled onto the
            scaled template's grid before being overlaid with the 'bwr' map.
        out_dir: output directory; created if it does not exist.

    The figure is written to ``<out_dir>/template_masking.png``.
    """
    import os
    import SimpleITK as sitk
    # set default threader to platform to avoid freezing with MultiProc https://github.com/SimpleITK/SimpleITK/issues/1239
    sitk.ProcessObject_SetGlobalDefaultThreader('Platform')
    import matplotlib.pyplot as plt
    from rabies.visualization import plot_3d, otsu_scaling

    os.makedirs(out_dir, exist_ok=True)

    scaled = otsu_scaling(template)

    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(4, 2 * 2))

    # Resample the mask onto the template grid so both layers share geometry.
    sitk_mask = sitk.Resample(sitk.ReadImage(mask, sitk.sitkFloat32), scaled)

    plot_3d(axes, scaled, fig=fig, vmin=0, vmax=1, cmap='gray')
    plot_3d(axes,
            sitk_mask,
            fig=fig,
            vmin=-1,
            vmax=1,
            cmap='bwr',
            alpha=0.3,
            cbar=False)
    plt.tight_layout()
    fig.savefig(out_dir + '/template_masking.png', bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it referenced indefinitely.
    plt.close(fig)
Ejemplo n.º 3
0
def inho_cor_diagnosis(raw_img, init_denoise, warped_mask, final_denoise,
                       name_source, out_dir):
    """Save a 4-column QC figure for intensity-inhomogeneity correction.

    Columns: raw image, initial correction, raw image with the resampled mask
    overlaid, final correction. Each image is Otsu-scaled and drawn with the
    'viridis' map over the three rows of ``plot_3d`` views.

    Args:
        raw_img: uncorrected image.
        init_denoise: image after the initial correction.
        warped_mask: path to the mask; read as float32 and resampled onto the
            scaled raw image's grid.
        final_denoise: image after the final correction.
        name_source: path whose file name, cut at the first ".nii", prefixes
            the output file.
        out_dir: output directory; created if it does not exist.

    The figure is written to ``<out_dir>/<name>_inho_cor.png``.
    """
    import os
    import pathlib
    import SimpleITK as sitk
    # set default threader to platform to avoid freezing with MultiProc https://github.com/SimpleITK/SimpleITK/issues/1239
    sitk.ProcessObject_SetGlobalDefaultThreader('Platform')
    import matplotlib.pyplot as plt
    from rabies.visualization import plot_3d, otsu_scaling

    filename_template = pathlib.Path(name_source).name.rsplit(".nii")[0]
    os.makedirs(out_dir, exist_ok=True)
    prefix = out_dir + '/' + filename_template

    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(12 * 4, 2 * 3))

    def _panel(col, title, image):
        # Title the top cell of a column and draw the image down the column.
        axes[0, col].set_title(title, fontsize=30, color='white')
        plot_3d(axes[:, col], image, fig=fig, vmin=0, vmax=1, cmap='viridis')

    raw_scaled = otsu_scaling(raw_img)
    _panel(0, 'Raw Image', raw_scaled)

    # The mask is displayed on top of the raw image, resampled to its grid.
    _panel(2, 'Resampled Mask', raw_scaled)
    sitk_mask = sitk.Resample(
        sitk.ReadImage(warped_mask, sitk.sitkFloat32), raw_scaled)
    plot_3d(axes[:, 2],
            sitk_mask,
            fig=fig,
            vmin=-1,
            vmax=1,
            cmap='bwr',
            alpha=0.3,
            cbar=False)

    _panel(1, 'Initial Correction', otsu_scaling(init_denoise))
    _panel(3, 'Final Correction', otsu_scaling(final_denoise))

    plt.tight_layout()
    fig.savefig(f'{prefix}_inho_cor.png', bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it referenced indefinitely.
    plt.close(fig)
Ejemplo n.º 4
0
def masked_plot(fig, axes, img, scaled, percentile=0.01, vmax=None):
    """Overlay a percentile-masked map on a grayscale background.

    Args:
        fig: matplotlib figure passed through to ``plot_3d``.
        axes: a single axis, or a list/tuple/ndarray of axes. A single axis
            gets only a coronal view; a sequence gets sagittal, coronal and
            horizontal views.
        img: SimpleITK image of the map to display. It is multiplied by the
            mask from ``percent_masking`` and resampled onto ``scaled``.
        scaled: SimpleITK background image, drawn in grayscale.
        percentile: forwarded to ``percent_masking``.
        vmax: colour limit; the overlay spans ``-vmax``..``vmax``. Defaults
            to the maximum of the *unmasked* ``img``.

    Returns:
        The colorbar list returned by ``plot_3d`` for the overlay.
    """
    mask = percent_masking(img, percentile=percentile)

    data = sitk.GetArrayFromImage(img)
    masked = sitk.GetImageFromArray(data * mask)
    masked.CopyInformation(img)

    if vmax is None:
        vmax = data.max()

    if isinstance(axes, (np.ndarray, list, tuple)):
        planes = ('sagittal', 'coronal', 'horizontal')
    else:
        axes = [axes]
        # Trailing comma: ('coronal') would just be the string 'coronal'.
        planes = ('coronal',)

    plot_3d(axes,
            scaled,
            fig,
            vmin=0,
            vmax=1,
            cmap='gray',
            alpha=1,
            cbar=False,
            num_slices=6,
            planes=planes)
    # resample to match template
    sitk_img = sitk.Resample(masked, scaled)
    cbar_list = plot_3d(axes,
                        sitk_img,
                        fig,
                        vmin=-vmax,
                        vmax=vmax,
                        cmap='cold_hot',
                        alpha=1,
                        cbar=True,
                        threshold=vmax * 0.001,
                        num_slices=6,
                        planes=planes)
    return cbar_list
Ejemplo n.º 5
0
def _recover_to_sitk(mask_file, data):
    """Rebuild a volume with ``analysis_functions.recover_3D`` and load it with SimpleITK.

    The intermediate image is saved to ``temp_img.nii.gz`` in the current
    working directory (overwritten on every call) and read back.
    """
    analysis_functions.recover_3D(mask_file, data).to_filename('temp_img.nii.gz')
    return sitk.ReadImage('temp_img.nii.gz')


def _upper_percentile_value(values, fraction=0.95):
    """Return the element at index ``int(n * fraction)`` of the sorted, flattened *values*."""
    vector = values.flatten()
    vector.sort()
    return vector[int(len(vector) * fraction)]


def _plot_voxelwise_row(fig, axes, scaled, sitk_img, title, cbar_label,
                        labelpad, **overlay_kwargs):
    """Draw one row of the spatial diagnosis figure.

    Draws the grayscale ``scaled`` background, then ``sitk_img`` as a colour
    overlay (extra ``plot_3d`` keywords such as vmin/vmax/cmap/threshold come
    from ``overlay_kwargs``), then labels every colorbar and axis title.
    """
    from rabies.visualization import plot_3d
    plot_3d(axes, scaled, fig, vmin=0, vmax=1,
            cmap='gray', alpha=1, cbar=False, num_slices=6)
    cbar_list = plot_3d(axes, sitk_img, fig, alpha=1, cbar=True,
                        num_slices=6, **overlay_kwargs)
    for cbar in cbar_list:
        cbar.ax.get_yaxis().labelpad = labelpad
        cbar.set_label(cbar_label, fontsize=15, rotation=270, color='white')
    for ax in axes:
        ax.set_title(title, fontsize=30, color='white')


def scan_diagnosis(bold_file, mask_file_dict, temporal_info, spatial_info, CR_data_dict, regional_grayplot=False):
    """Build the temporal and spatial diagnosis figures for one scan.

    Args:
        bold_file: BOLD timeseries passed to ``grayplot``.
        mask_file_dict: must provide 'template_file' and 'brain_mask'; also
            forwarded to ``grayplot``.
        temporal_info: per-timepoint traces ('FD_trace', 'DVARS',
            'edge_trace', 'WM_trace', 'CSF_trace', 'predicted_time',
            'VE_temporal', 'signal_trace', 'noise_trace'). Not modified.
        spatial_info: voxelwise vectors ('temporal_std', 'predicted_std',
            'VE_spatial', 'GS_corr', 'DVARS_corr', 'FD_corr') and 'DR_BOLD',
            whose rows are individual component maps.
        CR_data_dict: 'confounds_csv' path plus 'time_range' and 'frame_mask'
            used to select the analysed timepoints of the motion parameters.
        regional_grayplot: currently ignored (the regional grayplot is
            disabled and the argument is overridden to False).

    Returns:
        ``(fig, fig2)``: ``fig`` holds the grayplot and the timecourse panels;
        ``fig2`` holds one row per voxelwise map plus one per DR_BOLD component.
    """
    template_file = mask_file_dict['template_file']

    fig = plt.figure(figsize=(6, 18))

    ax0 = fig.add_subplot(3, 1, 1)
    ax1 = fig.add_subplot(12, 1, 5)
    ax1_ = fig.add_subplot(12, 1, 6)
    ax2 = fig.add_subplot(6, 1, 4)
    ax3 = fig.add_subplot(6, 1, 5)
    ax4 = fig.add_subplot(6, 1, 6)

    # The regional grayplot is disabled: the argument is deliberately
    # overridden so the branch below never runs.
    regional_grayplot = False
    if regional_grayplot:

        from mpl_toolkits.axes_grid1 import make_axes_locatable
        divider = make_axes_locatable(ax0)

        im, slice_alt, region_mask_label = grayplot_regional(
            bold_file, mask_file_dict, fig, ax0)
        ax0.yaxis.labelpad = 40
        ax_slice = divider.append_axes('left', size='5%', pad=0.0)
        ax_label = divider.append_axes('left', size='5%', pad=0.0)

        ax_slice.imshow(slice_alt.reshape(-1, 1), cmap='gray',
                        vmin=0, vmax=1.1, aspect='auto')
        ax_label.imshow(region_mask_label.reshape(-1, 1),
                        cmap='Spectral', aspect='auto')
        ax_slice.axis('off')
        ax_label.axis('off')

    else:
        grayplot(bold_file, mask_file_dict, fig, ax0)

    ax0.set_ylabel('Voxels', fontsize=20)
    for side in ('right', 'top', 'bottom', 'left'):
        ax0.spines[side].set_visible(False)
    ax0.axes.get_yaxis().set_ticks([])

    fd_trace = temporal_info['FD_trace'].to_numpy()
    x = range(len(fd_trace))
    for ax in (ax0, ax1, ax1_, ax2, ax3, ax4):
        ax.set_xlim([0, len(fd_trace) - 1])

    # plot the motion timecourses, restricted to the analysed timepoints
    time_range = CR_data_dict['time_range']
    frame_mask = CR_data_dict['frame_mask']
    df = pd.read_csv(CR_data_dict['confounds_csv'])

    for column in ('mov1', 'mov2', 'mov3'):
        ax1.plot(x, df[column].to_numpy()[time_range][frame_mask])
    ax1.legend(['translation 1', 'translation 2', 'translation 3'],
               loc='center left', bbox_to_anchor=(1.15, 0.5))
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    plt.setp(ax1.get_xticklabels(), visible=False)

    for column in ('rot1', 'rot2', 'rot3'):
        ax1_.plot(x, df[column].to_numpy()[time_range][frame_mask])
    ax1_.legend(['rotation 1', 'rotation 2', 'rotation 3'],
                loc='center left', bbox_to_anchor=(1.15, 0.5))
    plt.setp(ax1_.get_xticklabels(), visible=False)
    ax1_.spines['right'].set_visible(False)
    ax1_.spines['top'].set_visible(False)

    ax2.plot(x, fd_trace, 'r')
    ax2.set_ylabel('FD in mm', fontsize=15)
    # Copy before blanking the first frame so the caller's temporal_info
    # is left untouched.
    DVARS = temporal_info['DVARS'].copy()
    DVARS[0] = None
    ax2_ = ax2.twinx()
    ax2_.plot(x, DVARS, 'b')
    ax2_.set_ylabel('DVARS', fontsize=15)
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2_.spines['top'].set_visible(False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.setp(ax2_.get_xticklabels(), visible=False)
    ax2.legend(['Framewise \nDisplacement (FD)'
                ], loc='center left', bbox_to_anchor=(1.15, 0.6))
    ax2_.legend(['DVARS'
                ], loc='center left', bbox_to_anchor=(1.15, 0.4))

    for key in ('edge_trace', 'WM_trace', 'CSF_trace', 'predicted_time'):
        ax3.plot(x, temporal_info[key])
    ax3.set_ylabel('Mask L2-norm', fontsize=15)
    ax3_ = ax3.twinx()
    ax3_.plot(x, temporal_info['VE_temporal'], 'darkviolet')
    ax3_.set_ylabel('CR R^2', fontsize=15)
    ax3_.spines['right'].set_visible(False)
    ax3_.spines['top'].set_visible(False)
    plt.setp(ax3_.get_xticklabels(), visible=False)
    ax3.legend(['Edge Mask', 'WM Mask', 'CSF Mask', 'CR prediction'
                ], loc='center left', bbox_to_anchor=(1.15, 0.7))
    ax3_.legend(['CR R^2'
                ], loc='center left', bbox_to_anchor=(1.15, 0.3))
    ax3_.set_ylim([0, 1])

    ax4.plot(x, temporal_info['signal_trace'])
    ax4.plot(x, temporal_info['noise_trace'])
    ax4.legend(['BOLD components', 'Confound components'
                ], loc='center left', bbox_to_anchor=(1.15, 0.5))
    ax4.spines['right'].set_visible(False)
    ax4.spines['top'].set_visible(False)
    ax4.set_xlabel('Timepoint', fontsize=15)
    ax4.set_ylabel('Abs. Beta \ncoefficients (Avg.)', fontsize=15)

    dr_maps = spatial_info['DR_BOLD']
    mask_file = mask_file_dict['brain_mask']

    nrows = 6 + dr_maps.shape[0]

    fig2, axes2 = plt.subplots(nrows=nrows, ncols=3, figsize=(12*3, 2*nrows))
    plt.tight_layout()

    from rabies.visualization import otsu_scaling

    # The same background is reused for every row.
    scaled = otsu_scaling(template_file)

    # Rows 0-1: temporal s.d. maps, colour range capped at the 95th percentile.
    std_rows = (('temporal_std', 'BOLD-Temporal s.d.'),
                ('predicted_std', 'CR-Temporal s.d.'))
    for row, (key, title) in enumerate(std_rows):
        std_map = spatial_info[key]
        _plot_voxelwise_row(fig2, axes2[row, :], scaled,
                            _recover_to_sitk(mask_file, std_map),
                            title, 'Standard \n Deviation', 35,
                            vmin=0, vmax=_upper_percentile_value(std_map),
                            cmap='inferno')

    # Row 2: variance explained by confound regression.
    _plot_voxelwise_row(fig2, axes2[2, :], scaled,
                        _recover_to_sitk(mask_file, spatial_info['VE_spatial']),
                        'CR R^2', 'CR R^2', 20,
                        vmin=0, vmax=1, cmap='inferno', threshold=0.1)

    # Rows 3-5: correlation maps.
    corr_rows = (('GS_corr', 'Global Signal Correlation'),
                 ('DVARS_corr', 'DVARS Correlation'),
                 ('FD_corr', 'FD Correlation'))
    for row, (key, title) in enumerate(corr_rows, start=3):
        _plot_voxelwise_row(fig2, axes2[row, :], scaled,
                            _recover_to_sitk(mask_file, spatial_info[key]),
                            title, "Pearson's r", 20,
                            vmin=-1, vmax=1, cmap='cold_hot', threshold=0.1)

    # Remaining rows: one per dual-regression BOLD component.
    for i in range(dr_maps.shape[0]):
        axes = axes2[i + 6, :]
        sitk_img = _recover_to_sitk(mask_file, dr_maps[i, :])
        cbar_list = masked_plot(fig2, axes, sitk_img, scaled, percentile=0.015, vmax=None)

        for cbar in cbar_list:
            cbar.ax.get_yaxis().labelpad = 35
            cbar.set_label("Beta \nCoefficient", fontsize=15, rotation=270, color='white')
        for ax in axes:
            ax.set_title(f'BOLD component {i}', fontsize=30, color='white')

    return fig, fig2
Ejemplo n.º 6
0
def spatial_crosscorrelations(merged, scaled, mask_file, fig_path):
    """Plot voxelwise Spearman cross-correlations between diagnosis features.

    For every pair (i, j) with i > j among the features, the two feature
    arrays (presumably scans x voxels; confirm with callers) are passed to
    ``elementwise_spearman``. The result is projected into 3D with
    ``recover_3D`` and shown as a coronal overlay on ``scaled``. Only the
    first six features are used as columns.

    Args:
        merged: iterable of per-scan dicts with keys 'temporal_std',
            'predicted_std', 'VE_spatial', 'GS_corr', 'DVARS_corr',
            'FD_corr', 'DR_BOLD' and 'dual_ICA_maps'; the last two hold
            stacked maps ('dual_ICA_maps' may be empty).
        scaled: SimpleITK background image.
        mask_file: mask used by ``recover_3D`` to rebuild volumes.
        fig_path: output path for the saved figure.

    Raises:
        ValueError: if ``merged`` contains no scans.
    """
    scalar_keys = [
        'temporal_std', 'predicted_std', 'VE_spatial', 'GS_corr', 'DVARS_corr',
        'FD_corr'
    ]

    merged = list(merged)
    if not merged:
        raise ValueError(
            'spatial_crosscorrelations requires at least one scan in `merged`.')

    voxelwise_list = []
    for scan_data in merged:
        parts = [
            np.array([scan_data[key] for key in scalar_keys]),
            np.array(scan_data['DR_BOLD']),
        ]
        if len(scan_data['dual_ICA_maps']) > 0:
            parts.append(np.array(scan_data['dual_ICA_maps']))
        voxelwise_list.append(np.concatenate(parts, axis=0))
    # Map count taken from the last scan, as in the per-scan loop above.
    num_prior_maps = len(merged[-1]['DR_BOLD'])
    voxelwise_array = np.array(voxelwise_list)

    label_name = [
        'BOLD-Temporal s.d.', 'CR-Temporal s.d.', 'CR R^2', 'GS corr',
        'DVARS corr', 'FD corr'
    ]
    label_name += [
        f'BOLD Dual Regression map {i}' for i in range(num_prior_maps)
    ]
    label_name += [f'BOLD Dual ICA map {i}' for i in range(num_prior_maps)]

    ncols = 6
    nrows = voxelwise_array.shape[1]
    fig, axes = plt.subplots(nrows=nrows,
                             ncols=ncols,
                             figsize=(12 * ncols, 2 * nrows))
    for i, x_label in enumerate(label_name[:nrows]):
        for j, y_label in enumerate(label_name[:ncols]):
            ax = axes[i, j]
            # Only the strictly lower triangle is drawn.
            if i <= j:
                ax.axis('off')
                continue

            corr = elementwise_spearman(voxelwise_array[:, i, :],
                                        voxelwise_array[:, j, :])

            # Trailing comma: ('coronal') would just be the string 'coronal'.
            plot_3d([ax],
                    scaled,
                    fig,
                    vmin=0,
                    vmax=1,
                    cmap='gray',
                    alpha=1,
                    cbar=False,
                    num_slices=6,
                    planes=('coronal',))
            recover_3D(mask_file, corr).to_filename('temp_img.nii.gz')
            sitk_img = sitk.ReadImage('temp_img.nii.gz')
            cbar_list = plot_3d([ax],
                                sitk_img,
                                fig,
                                vmin=-1.0,
                                vmax=1.0,
                                cmap='cold_hot',
                                alpha=1,
                                cbar=True,
                                threshold=0.1,
                                num_slices=6,
                                planes=('coronal',))
            ax.set_title(f'Cross-correlation for \n{x_label} and {y_label}',
                         fontsize=20,
                         color='white')
            for cbar in cbar_list:
                cbar.ax.get_yaxis().labelpad = 20
                cbar.set_label("Spearman rho",
                               fontsize=12,
                               rotation=270,
                               color='white')

    fig.savefig(fig_path, bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it referenced indefinitely.
    plt.close(fig)
Ejemplo n.º 7
0
def template_info(anat_template, opts, out_dir):
    """Save a QC figure of the anatomical template and its associated masks.

    Columns: template alone, then brain, WM, CSF and vascular masks, then the
    atlas labels. Each overlay is drawn on the Otsu-scaled template after
    being resampled onto its grid.

    Args:
        anat_template: anatomical template image.
        opts: object exposing ``brain_mask``, ``WM_mask``, ``CSF_mask``,
            ``vascular_mask`` and ``labels`` paths (converted with ``str``).
        out_dir: output directory; created if it does not exist.

    The figure is written to ``<out_dir>/template_files.png``.
    """
    import os
    import SimpleITK as sitk
    # set default threader to platform to avoid freezing with MultiProc https://github.com/SimpleITK/SimpleITK/issues/1239
    sitk.ProcessObject_SetGlobalDefaultThreader('Platform')
    import matplotlib.pyplot as plt
    from rabies.visualization import plot_3d, otsu_scaling

    # (column title, mask path, overlay alpha) for columns 1-4.
    mask_panels = [
        ('Brain Mask', str(opts.brain_mask), 0.3),
        ('WM Mask', str(opts.WM_mask), 0.5),
        ('CSF Mask', str(opts.CSF_mask), 0.5),
        ('Vascular Mask', str(opts.vascular_mask), 0.5),
    ]
    labels = str(opts.labels)
    os.makedirs(out_dir, exist_ok=True)

    scaled = otsu_scaling(anat_template)

    fig, axes = plt.subplots(nrows=3, ncols=6, figsize=(4 * 6, 2 * 2))

    axes[0, 0].set_title('Anatomical Template', fontsize=30, color='white')
    plot_3d(axes[:, 0], scaled, fig=fig, vmin=0, vmax=1, cmap='gray')

    for col, (title, mask_path, alpha) in enumerate(mask_panels, start=1):
        # resample mask to match template (default linear interpolation)
        sitk_mask = sitk.Resample(
            sitk.ReadImage(mask_path, sitk.sitkFloat32), scaled)
        axes[0, col].set_title(title, fontsize=30, color='white')
        plot_3d(axes[:, col], scaled, fig=fig, vmin=0, vmax=1, cmap='gray')
        plot_3d(axes[:, col],
                sitk_mask,
                fig=fig,
                vmin=-1,
                vmax=1,
                cmap='bwr',
                alpha=alpha,
                cbar=False)

    # Atlas labels are categorical: nearest-neighbour resampling keeps the
    # original label values instead of interpolating between neighbouring
    # labels at region borders.
    sitk_labels = sitk.Resample(
        sitk.ReadImage(labels, sitk.sitkFloat32), scaled, sitk.Transform(),
        sitk.sitkNearestNeighbor)
    axes[0, 5].set_title('Atlas Labels', fontsize=30, color='white')
    plot_3d(axes[:, 5], scaled, fig=fig, vmin=0, vmax=1, cmap='gray')
    plot_3d(axes[:, 5],
            sitk_labels,
            fig=fig,
            vmin=1,
            vmax=sitk.GetArrayFromImage(sitk_labels).max(),
            cmap='rainbow',
            alpha=0.5,
            cbar=False)
    plt.tight_layout()

    fig.savefig(out_dir + '/template_files.png', bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it referenced indefinitely.
    plt.close(fig)
Ejemplo n.º 8
0
def temporal_features(bold_file, confounds_csv, FD_csv, rabies_data_type,
                      name_source, out_dir):
    """Plot motion/FD timecourses and temporal s.d. / tSNR maps of a BOLD run.

    Args:
        bold_file: path to the BOLD timeseries; read with SimpleITK using
            ``rabies_data_type`` as the pixel type. Time is taken as axis 0
            of the resulting array.
        confounds_csv: CSV with 'mov1'-'mov3' and 'rot1'-'rot3' columns.
        FD_csv: CSV with a 'Mean' framewise-displacement column.
        rabies_data_type: SimpleITK pixel type used to read ``bold_file``.
        name_source: path whose file name, cut at the first ".nii", prefixes
            the output figure.
        out_dir: output directory for the figure; created if missing.

    Returns:
        ``(std_filename, tSNR_filename)``: absolute paths of ``tSTD.nii.gz``
        and ``tSNR.nii.gz``, written in the current working directory.
    """
    import os
    import pathlib
    filename_template = pathlib.Path(name_source).name.rsplit(".nii")[0]
    os.makedirs(out_dir, exist_ok=True)
    prefix = out_dir + '/' + filename_template

    import numpy as np
    import SimpleITK as sitk
    import matplotlib.pyplot as plt
    import pandas as pd
    from rabies.visualization import plot_3d
    from rabies.utils import copyInfo_3DImage

    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(20, 5))

    # plot the motion timecourses
    df = pd.read_csv(confounds_csv)
    ax = axes[0, 0]
    for column in ('mov1', 'mov2', 'mov3'):
        ax.plot(df[column])
    ax.legend(['mov1', 'mov2', 'mov3'])
    ax.set_title('Translation parameters', fontsize=30, color='white')
    ax = axes[1, 0]
    for column in ('rot1', 'rot2', 'rot3'):
        ax.plot(df[column])
    ax.legend(['rot1', 'rot2', 'rot3'])
    ax.set_title('Rotation parameters', fontsize=30, color='white')

    df = pd.read_csv(FD_csv)
    ax = axes[2, 0]
    ax.plot(df['Mean'], color='r')
    ax.set_title('Framewise Displacement', fontsize=30, color='white')

    plt.tight_layout()

    # calculate STD and tSNR map on preprocessed timeseries
    img = sitk.ReadImage(bold_file, rabies_data_type)
    array = sitk.GetArrayFromImage(img)
    mean = array.mean(axis=0)
    std = array.std(axis=0)
    std_filename = os.path.abspath('tSTD.nii.gz')
    std_image = copyInfo_3DImage(sitk.GetImageFromArray(std, isVector=False),
                                 img)
    sitk.WriteImage(std_image, std_filename)

    # Constant voxels have std == 0: 0/0 gives NaN and x/0 gives +/-inf.
    # Both are reset to 0 so the saved map and the colour limit stay finite.
    with np.errstate(divide='ignore', invalid='ignore'):
        tSNR = np.divide(mean, std)
    tSNR[~np.isfinite(tSNR)] = 0
    tSNR_filename = os.path.abspath('tSNR.nii.gz')
    tSNR_image = copyInfo_3DImage(sitk.GetImageFromArray(tSNR, isVector=False),
                                  img)
    sitk.WriteImage(tSNR_image, tSNR_filename)

    axes[0, 1].set_title('Temporal STD', fontsize=30, color='white')
    # Colour limit at the 95th-percentile (by rank) of the s.d. values.
    std_sorted = np.sort(std, axis=None)
    std_vmax = std_sorted[int(len(std_sorted) * 0.95)]
    plot_3d(axes[:, 1],
            std_image,
            fig=fig,
            vmin=0,
            vmax=std_vmax,
            cmap='inferno',
            cbar=True)
    axes[0, 2].set_title('Temporal SNR', fontsize=30, color='white')
    plot_3d(axes[:, 2],
            tSNR_image,
            fig=fig,
            vmin=0,
            vmax=tSNR.max(),
            cmap='Spectral',
            cbar=True)

    fig.savefig(f'{prefix}_temporal_features.png', bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it referenced indefinitely.
    plt.close(fig)

    return std_filename, tSNR_filename