示例#1
0
    def display_unit(self):
        """Adds slice collage, with seg overlays on MRI in each panel.

        Surface renders (cortical mode only):
            When ``'cortical'`` is in ``self.vis_type``, the images listed in
            ``self.surface_vis_paths[self.current_unit_id]`` are drawn into the
            leading axes, each labelled with its hemisphere and view. This
            happens unless ``self.no_surface_vis`` is set or no renders exist
            for the unit. In that case a notice is printed and also written
            into ``self.axes[1]``.

        Slice panels:
            For each (dimension, slice) pair chosen by ``pick_slices``, the MRI
            slice is drawn into the axes starting at
            ``self.volumetric_start_index``. It is then overlaid either with a
            colour-mapped segmentation (``'volumetric'``) or with
            segmentation contours (``'contour'``).

        The histogram is refreshed at the end.
        """

        if 'cortical' in self.vis_type:
            unit_id = self.current_unit_id
            have_renders = (not self.no_surface_vis
                            and unit_id in self.surface_vis_paths)
            if have_renders:
                # dict keyed by (hemi, view) -> path of a pre-rendered image
                renders = self.surface_vis_paths[unit_id]
                for panel_idx, ((hemi, view), img_path) in enumerate(renders.items()):
                    plt.sca(self.axes[panel_idx])
                    surf_handle = plt.imshow(mpimg.imread(img_path))
                    self.axes[panel_idx].text(0, 0, '{} {}'.format(hemi, view))
                    self.UI.data_handles.append(surf_handle)
            else:
                notice = 'no surface visualizations\navailable or disabled'
                print('{} for {}'.format(notice, unit_id))
                self.axes[1].text(0.5, 0.5, notice)

        # identical rendering options for the MRI and the volumetric seg layers
        image_kwargs = dict(interpolation='none', aspect='equal', origin='lower')

        chosen = pick_slices(self.current_seg, self.views,
                             self.num_slices_per_view)
        for offset, (axis_dim, slice_pos) in enumerate(chosen):
            target_ax = self.axes[self.volumetric_start_index + offset]
            plt.sca(target_ax)
            mri_slice = get_axis(self.current_t1_mri, axis_dim, slice_pos)
            seg_slice = get_axis(self.current_seg, axis_dim, slice_pos)

            mri_colored = self.mri_mapper.to_rgba(mri_slice, alpha=self.alpha_mri)
            mri_handle = plt.imshow(mri_colored, **image_kwargs)
            self.UI.data_handles.append(mri_handle)

            if 'volumetric' in self.vis_type:
                seg_colored = self.seg_mapper.to_rgba(seg_slice,
                                                      alpha=self.alpha_seg)
                overlay_handle = plt.imshow(seg_colored, **image_kwargs)
                # only registered as toggleable; it is not added to
                # self.UI.data_handles (that registration was left disabled)
                self.togglable_handles.append(overlay_handle)
                del seg_colored
            elif 'contour' in self.vis_type:
                contour_sets = self.plot_contours_in_slice(seg_slice, target_ax)
                for cset in contour_sets:
                    self.togglable_handles.extend(cset.collections)
                    # also tracked in data_handles so they are cleared upon review
                    self.UI.data_handles.extend(cset.collections)

            # drop per-slice arrays before building the next panel
            del seg_slice, mri_slice, mri_colored

        self.update_histogram()
示例#2
0
    def mix_and_display(self):
        """Static mix and display.

        For each (dimension, slice) pair in ``self.slices``:

        - the matching slices of ``self.image_one`` and ``self.image_two``
          are combined with ``self.mixer``;
        - the result is assigned to the image handle at the same position in
          ``self.h_images``.
        """

        # TODO maintain a dict mixed[vis_type] to do computation only once
        for position, (axis_dim, slice_pos) in enumerate(self.slices):
            first = get_axis(self.image_one, axis_dim, slice_pos)
            second = get_axis(self.image_two, axis_dim, slice_pos)
            blended = self.mixer(first, second)
            # the mixed slice is already RGB (m x p x 3), so the colormap
            #   previously set on the handle (gray) has no effect on it
            self.h_images[position].set_data(blended)
示例#3
0
    def attach_image_to_foreground_axes(self, image3d, cmap='gray'):
        """Attaches a given image to the foreground axes and bring it forth

        Steps:

        1. ``image3d`` is cropped with ``self.padding`` and rescaled with
           ``scale_0to1``.
        2. The slices chosen by ``pick_slices`` are loaded, with colormap
           ``cmap``, into the handles of ``self.images_fg``.
        3. Every axes in ``self.fg_axes`` is made visible and raised to
           ``self.layer_order_zoomedin``.
        """

        prepared = scale_0to1(crop_image(image3d, self.padding))
        chosen = pick_slices(prepared, self.views, self.num_slices_per_view)
        for panel_idx, (axis_dim, slice_pos) in enumerate(chosen):
            panel_data = get_axis(prepared, axis_dim, slice_pos)
            self.images_fg[panel_idx].set(data=panel_data, cmap=cmap)

        # reveal the foreground layer above the regular panels
        for fg_ax in self.fg_axes:
            fg_ax.set(visible=True, zorder=self.layer_order_zoomedin)
示例#4
0
    def show_image(self, img, annot=None):
        """Display the requested slices of an image on the existing axes.

        Each (dimension, slice) pair in ``self.slices`` is extracted from
        ``img`` and pushed into the image handle at the same position in
        ``self.h_images``.

        If ``annot`` is given, it is passed to ``self._identify_foreground``.
        Otherwise the foreground annotation handle is hidden.
        """

        for panel_idx, (axis_dim, slice_pos) in enumerate(self.slices):
            handle = self.h_images[panel_idx]
            handle.set_data(get_axis(img, axis_dim, slice_pos))

        if annot is None:
            self.fg_annot_h.set_visible(False)
        else:
            self._identify_foreground(annot)
示例#5
0
    def display_unit(self):
        """Adds slice collage to the given axes

        Steps:

        1. ``self.current_img`` is cropped with ``self.padding`` and rescaled
           with ``scale_0to1``.
        2. The slices chosen by ``pick_slices`` are loaded into the
           ``self.images`` handles.
        3. The histogram is refreshed from the processed image.
        """

        normalized = scale_0to1(crop_image(self.current_img, self.padding))

        chosen = pick_slices(normalized, self.views, self.num_slices_per_view)
        for panel_idx, (axis_dim, slice_pos) in enumerate(chosen):
            panel_data = get_axis(normalized, axis_dim, slice_pos)
            self.images[panel_idx].set_data(panel_data)

        # histogram reflects the cropped + rescaled image, not the raw one
        self.update_histogram(normalized)
示例#6
0
def overlay_images(qcw, mri, seg,
                   subject_id=None,
                   annot=None,
                   figsize=None,
                   padding=default_padding,
                   output_path=None):
    """Backend engine for overlaying a given seg on MRI with freesurfer label.

    Figure layout:

    - An optional first row holds pial-surface renders. It is produced by
      ``make_vis_pial_surface``, only when ``'cortical'`` is in
      ``qcw.vis_type`` and ``qcw.in_dir``, ``subject_id`` and ``qcw.out_dir``
      are all set.
    - The remaining panels show MRI slices picked from ``seg`` with
      ``pick_slices``. Each is overlaid with the segmentation, either as a
      colour-mapped image (``'volumetric'``) or as label contours
      (``'contour'``).

    Parameters
    ----------
    qcw : object
        Settings holder. Attributes read here: num_rows, num_slices, vis_type,
        in_dir, out_dir, views, alpha_mri, alpha_seg, contour_color.
    mri, seg : ndarray
        MRI volume and label volume. Both are passed through
        ``crop_to_seg_extents`` before display (presumably 3D; confirm).
    subject_id : optional
        Used only to locate surface renders.
    annot : str, optional
        Text shown as the figure super-title.
    figsize : sequence of two numbers, optional
        Figure size in inches. Derived from the grid layout when None.
    padding : optional
        Validated by ``check_params`` and then used for cropping.
    output_path : str, optional
        When given, the figure is saved to ``<output_path>_<layout>.png``,
        with spaces in the path replaced by underscores.

    Returns
    -------
    tuple
        ``(fig, handles_mri, handles_seg, figsize)``:

        - ``handles_mri`` holds one image handle per slice.
        - ``handles_seg`` holds the overlay handles: image handles, or the
          elements of the list/tuple returned by ``plot_contours_in_slice``.
    """

    num_rows_per_view, num_slices_per_view, padding = check_params(
        qcw.num_rows, qcw.num_slices, padding)
    mri, seg = crop_to_seg_extents(mri, seg, padding)

    surf_vis = dict()  # empty - no vis to include
    # TODO broaden this to include subcortical structures as well
    if ('cortical' in qcw.vis_type and qcw.in_dir is not None
            and subject_id is not None and qcw.out_dir is not None):
        surf_vis = make_vis_pial_surface(qcw.in_dir, subject_id, qcw.out_dir)
    num_surf_vis = len(surf_vis)

    # TODO calculation below is redundant, if surf vis does not fail
    #   i.e. if num_surf_vis is fixed, no need to recompute for every subject
    num_views = len(qcw.views)
    slices = pick_slices(seg, qcw.views, num_slices_per_view)
    num_volumetric_slices = len(slices)
    total_num_panels = num_volumetric_slices + num_surf_vis
    # one extra row is reserved for the surface renders, if there are any
    num_rows_for_surf_vis = 1 if num_surf_vis > 0 else 0
    num_rows = num_rows_per_view * num_views + num_rows_for_surf_vis
    num_cols = check_layout(total_num_panels, num_views, num_rows_per_view,
                            num_rows_for_surf_vis)

    plt.style.use('dark_background')

    if figsize is None:
        # NOTE(review): matplotlib figsize is (width, height); here width
        #   scales with rows, which may be transposed -- kept as-is so the
        #   returned/saved size stays unchanged.
        figsize = [4 * num_rows, 2 * num_cols]
    # squeeze=False always returns a 2D array of Axes, so flatten() also
    #   works for a 1x1 grid (a bare Axes object has no flatten method)
    fig, ax = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    ax = ax.flatten()

    display_params_mri = dict(interpolation='none', aspect='equal',
                              origin='lower', alpha=qcw.alpha_mri)
    display_params_seg = dict(interpolation='none', aspect='equal',
                              origin='lower', alpha=qcw.alpha_seg)

    normalize_labels = colors.Normalize(vmin=seg.min(), vmax=seg.max(), clip=True)
    fs_cmap = get_freesurfer_cmap(qcw.vis_type)
    seg_mapper = cm.ScalarMappable(norm=normalize_labels, cmap=fs_cmap)

    normalize_mri = colors.Normalize(vmin=mri.min(), vmax=mri.max(), clip=True)
    mri_mapper = cm.ScalarMappable(norm=normalize_mri, cmap='gray')

    # deciding colors for the whole image. Background (label 0) is removed
    #   by value; removing the first unique entry would drop a real label
    #   whenever the cropped seg contains no background voxel.
    unique_labels = np.unique(seg)
    unique_labels = unique_labels[unique_labels != 0]
    if len(unique_labels) == 1:
        color4label = [qcw.contour_color]
    else:
        color4label = seg_mapper.to_rgba(unique_labels)

    handles_seg = list()
    handles_mri = list()

    # display surfaces
    for sf_counter, ((hemi, view), spath) in enumerate(surf_vis.items()):
        plt.sca(ax[sf_counter])
        img = mpimg.imread(spath)
        plt.imshow(img)
        ax[sf_counter].text(0, 0, '{} {}'.format(hemi, view))
        plt.axis('off')

    # display slices
    for ax_counter, (dim_index, slice_num) in enumerate(slices):
        plt.sca(ax[ax_counter + num_surf_vis])

        slice_mri = get_axis(mri, dim_index, slice_num)
        slice_seg = get_axis(seg, dim_index, slice_num)

        # display MRI
        mri_rgb = mri_mapper.to_rgba(slice_mri)
        h_mri = plt.imshow(mri_rgb, **display_params_mri)

        # reset every iteration: when vis_type asks for no overlay, nothing is
        #   recorded (instead of a NameError or a stale handle from the last slice)
        h_seg = None
        if 'volumetric' in qcw.vis_type:
            seg_rgb = seg_mapper.to_rgba(slice_seg)
            h_seg = plt.imshow(seg_rgb, **display_params_seg)
        elif 'contour' in qcw.vis_type:
            h_seg = plot_contours_in_slice(slice_seg, unique_labels, color4label)

        plt.axis('off')

        handles_mri.append(h_mri)
        # an image handle has no len(); only sequences of contour handles
        #   are flattened into handles_seg
        if isinstance(h_seg, (list, tuple)):
            handles_seg.extend(h_seg)
        elif h_seg is not None:
            handles_seg.append(h_seg)

    # hiding unused axes
    for unused_ax in ax[total_num_panels:]:
        unused_ax.set_visible(False)

    if annot is not None:
        h_annot = fig.suptitle(annot, **cfg.annot_text_props)
        h_annot.set_position(cfg.position_annot_text)

    fig.set_size_inches(figsize)

    if output_path is not None:
        # no space left unused
        plt.subplots_adjust(**cfg.no_blank_area)
        output_path = output_path.replace(' ', '_')
        layout_str = 'v{}_ns{}_{}x{}'.format(''.join(str(v) for v in qcw.views),
                                             num_slices_per_view, num_rows, num_cols)
        fig.savefig(output_path + '_{}.png'.format(layout_str), bbox_inches='tight')

    # leaving some space on the right for review elements
    plt.subplots_adjust(**cfg.review_area)

    return fig, handles_mri, handles_seg, figsize