コード例 #1
0
    def init_layout(self,
                    view_set=cfg.defacing_view_set,
                    num_rows_per_view=cfg.defacing_num_rows_per_view,
                    num_slices_per_view=cfg.defacing_num_slices_per_view,
                    padding=cfg.default_padding):
        """Create the slice collage and the figure used for the review.

        Parameters
        ----------
        view_set
            Forwarded to ``Collage`` as ``view_set``.
        num_rows_per_view
            Forwarded to ``Collage`` as ``num_rows``.
        num_slices_per_view
            Forwarded to ``Collage`` as ``num_slices``.
        padding
            Stored on the instance as ``self.padding``.

        Sets ``self.display_params``, ``self.figsize``, ``self.collage``,
        ``self.fig`` and ``self.padding``, and switches the global matplotlib
        style to ``dark_background``.
        """

        plt.style.use('dark_background')

        # vmin/vmax are controlled, because we rescale all to [0, 1]
        self.display_params = dict(interpolation='none',
                                   aspect='equal',
                                   origin='lower',
                                   cmap='gray',
                                   vmin=0.0,
                                   vmax=1.0)
        self.figsize = cfg.default_review_figsize

        self.collage = Collage(view_set=view_set,
                               num_slices=num_slices_per_view,
                               num_rows=num_rows_per_view,
                               display_params=self.display_params,
                               bounding_rect=cfg.bbox_defacing_MRI_review,
                               figsize=self.figsize)
        self.fig = self.collage.fig

        # canvas.set_window_title() was deprecated in Matplotlib 3.4 and
        # removed in 3.6; the figure manager offers the same call on old and
        # new versions. A canvas without a manager has no window to title.
        window_title = 'VisualQC defacing : {} {} '.format(self.in_dir,
                                                           self.defaced_name)
        manager = getattr(self.fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(window_title)

        self.padding = padding
コード例 #2
0
ファイル: test_mrivis.py プロジェクト: snehashis1997/mrivis
def test_collage_class():
    """Smoke-test Collage: attach, transform_and_attach and printing.

    Each step is wrapped so that a failure is reported as a ValueError naming
    the step, with the underlying exception chained as its cause.
    """

    img_path = pjoin(base_dir, '3569_bl_PPMI.nii')
    img = read_image(img_path, None)
    scaled = scale_0to1(img)
    c = Collage(num_slices=15, view_set=(0, 1), num_rows=3)

    # `except Exception` (not a bare except) so that KeyboardInterrupt and
    # SystemExit are not misreported as test failures.
    try:
        c.attach(scaled)
    except Exception as exc:
        raise ValueError('Attach does not work') from exc

    try:
        c.transform_and_attach(scaled, np.square)
    except Exception as exc:
        raise ValueError('transform_and_attach does not work') from exc

    try:
        print(c)
    except Exception as exc:
        raise ValueError('repr implementation failed') from exc
コード例 #3
0
class RatingWorkflowDefacing(BaseWorkflowVisualQC, ABC):
    """Rating workflow for defaced MRI scans.

    Displays slices of the defaced and the original scan in a collage
    (either one alone, or both mixed into colour channels), together with
    the rendered images listed for each unit.
    """

    def __init__(self,
                 id_list,
                 images_for_id,
                 in_dir,
                 out_dir,
                 defaced_name,
                 mri_name,
                 render_name,
                 issue_list=cfg.defacing_default_issue_list,
                 vis_type='defacing'):
        """Store the configuration and build the figure layout.

        Parameters
        ----------
        id_list, in_dir, out_dir
            Forwarded to the base workflow constructor.
        images_for_id
            Mapping indexed by unit id; each entry is indexed with the keys
            'defaced', 'original' and 'render' when a unit is loaded.
        defaced_name
            Used in the experiment id and in the window title.
        mri_name, render_name
            Stored on the instance.
        issue_list
            Passed to the review interface.
        vis_type
            Stored on the instance.
        """

        super().__init__(
            id_list,
            in_dir,
            out_dir,
            show_unit_id=False,  # preventing bias/batch-effects
            outlier_method=None,
            outlier_fraction=None,
            outlier_feat_types=None,
            disable_outlier_detection=None)

        self.vis_type = vis_type
        self.issue_list = issue_list
        self.defaced_name = defaced_name
        self.mri_name = mri_name
        self.render_name = render_name
        self.images_for_id = images_for_id

        self.expt_id = 'rate_defaced_mri_{}'.format(self.defaced_name)
        self.suffix = self.expt_id
        self.current_alert_msg = None

        self.init_layout()

    def preprocess(self):
        """No-op: this workflow has nothing to preprocess."""

        pass

    def init_layout(self,
                    view_set=cfg.defacing_view_set,
                    num_rows_per_view=cfg.defacing_num_rows_per_view,
                    num_slices_per_view=cfg.defacing_num_slices_per_view,
                    padding=cfg.default_padding):
        """Create the slice collage and the figure used for the review.

        The first three parameters are forwarded to ``Collage`` (as
        ``view_set``, ``num_rows`` and ``num_slices``); ``padding`` is stored
        as ``self.padding``. Also switches the global matplotlib style to
        ``dark_background``.
        """

        plt.style.use('dark_background')

        # vmin/vmax are controlled, because we rescale all to [0, 1]
        self.display_params = dict(interpolation='none',
                                   aspect='equal',
                                   origin='lower',
                                   cmap='gray',
                                   vmin=0.0,
                                   vmax=1.0)
        self.figsize = cfg.default_review_figsize

        self.collage = Collage(view_set=view_set,
                               num_slices=num_slices_per_view,
                               num_rows=num_rows_per_view,
                               display_params=self.display_params,
                               bounding_rect=cfg.bbox_defacing_MRI_review,
                               figsize=self.figsize)
        self.fig = self.collage.fig

        # canvas.set_window_title() was deprecated in Matplotlib 3.4 and
        # removed in 3.6; the figure manager offers the same call on old and
        # new versions. A canvas without a manager has no window to title.
        window_title = 'VisualQC defacing : {} {} '.format(self.in_dir,
                                                           self.defaced_name)
        manager = getattr(self.fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(window_title)

        self.padding = padding

    def prepare_UI(self):
        """Show the figure and attach the review interface to it."""

        self.open_figure()
        self.add_UI()

    def open_figure(self):
        """Show the figure(s) without blocking the caller."""

        plt.show(block=False)

    def add_UI(self):
        """Create the review interface and connect mouse/keyboard callbacks."""

        # 2 keys for same combination exist to account for time delays in key presses
        map_key_to_callback = {
            'alt+b': self.show_defaced,
            'b+alt': self.show_defaced,
            'alt+o': self.show_original,
            'o+alt': self.show_original,
            'alt+m': self.show_mixed,
            'm+alt': self.show_mixed,
        }
        self.UI = DefacingInterface(
            self.collage.fig,
            self.collage.flat_grid,
            self.issue_list,
            next_button_callback=self.next,
            quit_button_callback=self.quit,
            processing_choice_callback=self.process_and_display,
            map_key_to_callback=map_key_to_callback)

        # connecting callbacks; the ids are kept so cleanup() can disconnect
        self.con_id_click = self.fig.canvas.mpl_connect(
            'button_press_event', self.UI.on_mouse)
        self.con_id_keybd = self.fig.canvas.mpl_connect(
            'key_press_event', self.UI.on_keyboard)
        # con_id_scroll = self.fig.canvas.mpl_connect('scroll_event', self.UI.on_scroll)

        self.fig.set_size_inches(self.figsize)

    def load_unit(self, unit_id):
        """Load and rescale the images of one unit.

        Reads the defaced and original volumes and all rendered images listed
        in ``self.images_for_id[unit_id]``, rescales both volumes, and creates
        the slice picker from the original volume.

        Returns
        -------
        skip_subject : bool
            True when the rescaled defaced or original image contains no
            nonzero values.

        Raises
        ------
        IOError
            If one of the rendered images cannot be read.
        """

        # starting fresh ('render_img_list' is the attribute actually set
        # below; 'render_img' is kept in the list for backward compatibility)
        for attr in ('defaced_img', 'orig_img',
                     'render_img', 'render_img_list'):
            if hasattr(self, attr):
                delattr(self, attr)

        self.defaced_img = read_image(self.images_for_id[unit_id]['defaced'],
                                      error_msg='defaced mri')
        self.orig_img = read_image(self.images_for_id[unit_id]['original'],
                                   error_msg='T1 mri')

        self.render_img_list = list()
        for rimg_path in self.images_for_id[unit_id]['render']:
            try:
                self.render_img_list.append(imread(rimg_path))
            except Exception as exc:
                # not a bare except: Ctrl-C must not be reported as an IO error
                raise IOError('Unable to read the 3D rendered image @\n {}'
                              ''.format(rimg_path)) from exc

        # crop, trim, and rescale
        self.defaced_img = rescale_without_outliers(
            self.defaced_img,
            padding=self.padding,
            trim_percentile=cfg.defacing_trim_percentile)
        self.orig_img = rescale_without_outliers(
            self.orig_img,
            padding=self.padding,
            trim_percentile=cfg.defacing_trim_percentile)
        self.currently_showing = None

        skip_subject = False
        if np.count_nonzero(self.defaced_img) == 0 or \
            np.count_nonzero(self.orig_img) == 0:
            skip_subject = True
            print('Defaced or original MR image is empty!')

        self.slice_picker = SlicePicker(self.orig_img,
                                        view_set=self.collage.view_set,
                                        num_slices=self.collage.num_slices,
                                        sampler=cfg.defacing_slice_locations)

        # # where to save the visualization to
        # out_vis_path = pjoin(self.out_dir,
        #   'visual_qc_{}_{}'.format(self.vis_type, unit_id))

        return skip_subject

    def process_and_display(self, user_choice):
        """Switch the display according to the option label chosen in the UI."""

        if user_choice in ('Defaced only', ):
            self.show_defaced()
        elif user_choice in ('Original only', ):
            self.show_original()
        elif user_choice in ('Mixed', 'Fused'):
            self.show_mixed()
        else:
            print('Chosen option seems to be not implemented!')

    def display_unit(self):
        """Show the rendered images and the MR slice collage (mixed mode)."""

        self.show_renders()
        self.show_mr_images()

    def show_renders(self):
        """Show every rendered image in its own axes.

        The axes are laid out on a grid inside
        ``cfg.bbox_defacing_render_review`` and stored in ``self.ax_render``.
        """

        # NOTE(review): new axes are added to the figure on every call and are
        # not removed here - confirm they are cleaned up elsewhere.
        num_cells = len(self.render_img_list)
        cell_extents = compute_cell_extents_grid(
            cfg.bbox_defacing_render_review,
            num_rows=cfg.defacing_num_rows_renders,
            num_cols=num_cells)

        self.ax_render = list()
        for img, ext in zip(self.render_img_list, cell_extents):
            ax = self.fig.add_axes(ext, frameon=False)
            ax.set_axis_off()
            ax.imshow(img)
            ax.set_visible(True)
            self.ax_render.append(ax)

    def show_defaced(self):
        """Show defaced only"""

        self.show_mr_images(vis_type='defaced')

    def show_original(self):
        """Show original only"""

        self.show_mr_images(vis_type='original')

    def show_mixed(self):
        """Show mixed"""

        self.show_mr_images(vis_type='mixed')

    def show_mr_images(self, vis_type='mixed'):
        """Redraw the collage with slices of the defaced/original images.

        Parameters
        ----------
        vis_type : str
            'defaced' or 'original' shows that image alone; 'mixed' stacks
            the original (scaled by 0.9) into the red channel and the
            defaced image into the green channel, with an empty blue channel.

        Raises
        ------
        ValueError
            If ``vis_type`` is not one of the three supported values.
        """

        # validate before touching the display, so a bad request does not
        # leave the collage cleared
        if vis_type not in ('mixed', 'defaced', 'original'):
            raise ValueError('Invalid vis_type. Must be either mixed, '
                             'defaced, or original')

        self.collage.clear()

        slice_pairs = self.slice_picker.get_slices_multi(
            [self.defaced_img, self.orig_img])
        for ax_index, (df, orig) in enumerate(slice_pairs):
            ax = self.collage.flat_grid[ax_index]
            if vis_type == 'mixed':
                #final_slice = mix_color(orig, df)
                red = 0.9 * orig
                grn = 1.0 * df
                blu = np.zeros_like(orig)
                ax.imshow(np.stack((red, grn, blu), axis=2),
                          **self.display_params)
            elif vis_type == 'defaced':
                ax.imshow(df, **self.display_params)
            else:  # 'original'
                ax.imshow(orig, **self.display_params)

        self.collage.show()

    def add_alerts(self):
        """No-op: this workflow does not add any alerts."""

        pass

    def cleanup(self):
        """Save the ratings, disconnect the callbacks and close all figures."""

        # save ratings
        self.save_ratings()

        self.fig.canvas.mpl_disconnect(self.con_id_click)
        self.fig.canvas.mpl_disconnect(self.con_id_keybd)
        plt.close('all')
コード例 #4
0
class RatingWorkflowT1(BaseWorkflowVisualQC, ABC):
    """
    Rating workflow without any overlay.

    Shows a slice collage of a single MR image together with a histogram of
    its nonzero intensities, and offers alternative renditions of the same
    image (saturated, background only, tails trimmed).
    """

    def __init__(self,
                 id_list,
                 in_dir,
                 out_dir,
                 issue_list,
                 mri_name,
                 in_dir_type,
                 images_for_id,
                 outlier_method, outlier_fraction,
                 outlier_feat_types, disable_outlier_detection,
                 prepare_first,
                 vis_type,
                 views, num_slices_per_view, num_rows_per_view):
        """Store the configuration, build the layout and the path getters.

        ``id_list``, ``in_dir``, ``out_dir`` and the four outlier-related
        arguments are forwarded to the base workflow constructor; ``views``,
        ``num_rows_per_view`` and ``num_slices_per_view`` go to
        ``init_layout``; the remaining arguments are stored on the instance.
        """

        super().__init__(id_list, in_dir, out_dir,
                         outlier_method, outlier_fraction,
                         outlier_feat_types, disable_outlier_detection)

        self.vis_type = vis_type
        self.issue_list = issue_list
        self.mri_name = mri_name
        self.in_dir_type = in_dir_type
        self.images_for_id = images_for_id
        self.expt_id = 'rate_mri_{}'.format(self.mri_name)
        self.suffix = self.expt_id
        self.current_alert_msg = None
        self.prepare_first = prepare_first

        self.init_layout(views, num_rows_per_view, num_slices_per_view)
        self.init_getters()

    def preprocess(self):
        """
        Preprocess the input data
            e.g. compute features, make complex visualizations etc.
            before starting the review process.

        Features are extracted only when outlier detection is enabled;
        ``detect_outliers`` is called in either case.
        """

        if not self.disable_outlier_detection:
            print('Preprocessing data - please wait .. '
                  '\n\t(or contemplate the vastness of universe! )')
            self.extract_features()
        self.detect_outliers()

        # no complex vis to generate - skipping

    def prepare_UI(self):
        """Show the figure, attach the review interface and the histogram."""

        self.open_figure()
        self.add_UI()
        self.add_histogram_panel()

    def init_layout(self, views, num_rows_per_view,
                    num_slices_per_view, padding=cfg.default_padding):
        """Create the slice collage and the figure used for the review.

        ``views``, ``num_rows_per_view`` and ``num_slices_per_view`` are
        forwarded to ``Collage`` (as ``view_set``, ``num_rows`` and
        ``num_slices``); ``padding`` is stored as ``self.padding``. Also
        switches the global matplotlib style to ``dark_background``.
        """

        plt.style.use('dark_background')

        # vmin/vmax are controlled, because we rescale all to [0, 1]
        self.display_params = dict(interpolation='none', aspect='equal',
                                   origin='lower', cmap='gray', vmin=0.0, vmax=1.0)
        self.figsize = cfg.default_review_figsize

        self.collage = Collage(view_set=views,
                               num_slices=num_slices_per_view,
                               num_rows=num_rows_per_view,
                               display_params=self.display_params,
                               bounding_rect=cfg.bounding_box_review,
                               figsize=self.figsize)
        self.fig = self.collage.fig

        # canvas.set_window_title() was deprecated in Matplotlib 3.4 and
        # removed in 3.6; the figure manager offers the same call on old and
        # new versions. A canvas without a manager has no window to title.
        window_title = 'VisualQC T1 MRI : {} {} '.format(self.in_dir,
                                                         self.mri_name)
        manager = getattr(self.fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(window_title)

        self.padding = padding

    def init_getters(self):
        """Initializes the getters methods for input paths and feature readers.

        The input path of a subject is built as
        ``<in_dir>/<id>/mri/<mri_name>`` for Freesurfer-type inputs, taken
        from ``images_for_id[id]['image']`` for BIDS inputs, and built as
        ``<in_dir>/<id>/<mri_name>`` otherwise.
        """

        # imported here rather than at module level
        # NOTE(review): presumably to avoid a circular import - confirm
        from visualqc.features import extract_T1_features
        self.feature_extractor = extract_T1_features

        if self.vis_type is not None and (
            self.vis_type in cfg.freesurfer_vis_types or self.in_dir_type in [
            'freesurfer', ]):
            self.path_getter_inputs = lambda sub_id: realpath(
                pjoin(self.in_dir, sub_id, 'mri', self.mri_name))
        else:
            if self.in_dir_type.upper() in ('BIDS', ):
                self.path_getter_inputs = lambda sub_id: self.images_for_id[
                    sub_id]['image']
            else:
                self.path_getter_inputs = lambda sub_id: realpath(
                    pjoin(self.in_dir, sub_id, self.mri_name))

    def open_figure(self):
        """Show the figure(s) without blocking the caller."""

        plt.show(block=False)

    def add_UI(self):
        """Create the review interface and connect mouse/keyboard callbacks."""

        # two keys for same combinations exist to account for time delays in key presses
        map_key_to_callback = {'alt+s': self.show_saturated,
                               's+alt': self.show_saturated,
                               'alt+b': self.show_background_only,
                               'b+alt': self.show_background_only,
                               'alt+t': self.show_tails_trimmed,
                               't+alt': self.show_tails_trimmed,
                               'alt+o': self.show_original,
                               'o+alt': self.show_original}
        self.UI = T1MriInterface(self.collage.fig, self.collage.flat_grid,
                                 self.issue_list,
                                 next_button_callback=self.next,
                                 quit_button_callback=self.quit,
                                 processing_choice_callback=self.process_and_display,
                                 map_key_to_callback=map_key_to_callback)

        # connecting callbacks; the ids are kept so cleanup() can disconnect
        self.con_id_click = self.fig.canvas.mpl_connect('button_press_event',
                                                        self.UI.on_mouse)
        self.con_id_keybd = self.fig.canvas.mpl_connect('key_press_event',
                                                        self.UI.on_keyboard)
        # con_id_scroll = self.fig.canvas.mpl_connect('scroll_event', self.UI.on_scroll)

        self.fig.set_size_inches(self.figsize)

    def add_histogram_panel(self):
        """Extra axis for histogram"""

        self.ax_hist = plt.axes(cfg.position_histogram_t1_mri)
        self.ax_hist.set_xticks(cfg.xticks_histogram_t1_mri)
        self.ax_hist.set_yticks([])
        self.ax_hist.set_autoscaley_on(True)
        self.ax_hist.set_prop_cycle('color', cfg.color_histogram_t1_mri)
        self.ax_hist.set_title(cfg.title_histogram_t1_mri, fontsize='small')

    def update_histogram(self, img):
        """Draw the density histogram of the nonzero values of ``img``.

        The histogram patches are appended to ``self.UI.data_handles``.
        """

        nonzero_values = img.ravel()[np.flatnonzero(img)]
        _, _, patches_hist = self.ax_hist.hist(nonzero_values, density=True,
                                               bins=cfg.num_bins_histogram_display)
        self.ax_hist.relim(visible_only=True)
        self.ax_hist.autoscale_view(scalex=False)  # xlim fixed to [0, 1]
        self.UI.data_handles.extend(patches_hist)

    def update_alerts(self):
        """Draw the current alert message on the figure, if there is one."""

        if self.current_alert_msg is not None:
            h_alert_text = self.fig.text(cfg.position_outlier_alert_t1_mri[0],
                                         cfg.position_outlier_alert_t1_mri[1],
                                         self.current_alert_msg, **cfg.alert_text_props)
            # adding it to list of elements to cleared when advancing to next subject
            self.UI.data_handles.append(h_alert_text)

    def add_alerts(self):
        """Brings up an alert if subject id is detected to be an outlier.

        A subject counts as flagged when its id is a key of
        ``self.by_sample``; the associated entries are printed and shown on
        the figure. Otherwise the current alert message is reset to None.
        """

        if self.current_unit_id in self.by_sample:
            alerts_list = self.by_sample[self.current_unit_id]
            print('\n\tFlagged as a possible outlier by these measures:\n\t\t{}'
                  ''.format('\t'.join(alerts_list)))

            # unpacking (rather than list + list) accepts any iterable of str
            strings_to_show = ['Flagged as an outlier:', *alerts_list]
            self.current_alert_msg = '\n'.join(strings_to_show)
            self.update_alerts()
        else:
            self.current_alert_msg = None

    def load_unit(self, unit_id):
        """Load, crop and rescale the image of one unit.

        Returns
        -------
        skip_subject : bool
            True when the cropped and rescaled image has no nonzero values.
        """

        # starting fresh: drop images cached for the previous unit
        for attr in ('current_img_raw', 'current_img',
                     'saturated_img', 'tails_trimmed_img', 'background_img'):
            if hasattr(self, attr):
                delattr(self, attr)

        t1_mri_path = self.path_getter_inputs(unit_id)
        self.current_img_raw = read_image(t1_mri_path, error_msg='T1 mri')
        # crop and rescale
        self.current_img = scale_0to1(crop_image(self.current_img_raw, self.padding))
        self.currently_showing = None

        skip_subject = False
        if np.count_nonzero(self.current_img) == 0:
            skip_subject = True
            print('MR image is empty!')

        # # where to save the visualization to
        # out_vis_path = pjoin(self.out_dir, 'visual_qc_{}_{}'.format(self.vis_type, unit_id))

        return skip_subject

    def display_unit(self):
        """Attach the current image to the collage and update the histogram."""

        # showing the collage
        self.collage.attach(self.current_img)
        # updating histogram
        self.update_histogram(self.current_img)

    def process_and_display(self, user_choice):
        """Switch the display according to the option label chosen in the UI.

        Unlike the keyboard callbacks, the processed views are requested with
        ``no_toggle=True``, so choosing the same option twice does not flip
        back to the original image.
        """

        if user_choice in ('Saturate',):
            self.show_saturated(no_toggle=True)
        elif user_choice in ('Background only',):
            self.show_background_only(no_toggle=True)
        elif user_choice in ('Tails_trimmed', 'Tails trimmed'):
            self.show_tails_trimmed(no_toggle=True)
        elif user_choice in ('Original',):
            self.show_original()
        else:
            print('Chosen option seems to be not implemented!')

    def show_saturated(self, no_toggle=False):
        """Show the image with its brighter intensities saturated.

        If this view is already showing and ``no_toggle`` is false, switches
        back to the original image instead. The saturated image is computed
        once per unit and cached in ``self.saturated_img``.
        """

        if self.currently_showing != 'saturated' or no_toggle:
            if not hasattr(self, 'saturated_img'):
                self.saturated_img = saturate_brighter_intensities(
                    self.current_img, percentile=cfg.saturate_perc_t1)
            self.collage.attach(self.saturated_img)
            self.currently_showing = 'saturated'
        else:
            self.show_original()

    def show_background_only(self, no_toggle=False):
        """Show the image with the foreground masked out.

        If this view is already showing and ``no_toggle`` is false, switches
        back to the original image instead.
        """

        if self.currently_showing != 'Background only' or no_toggle:
            self._compute_background()
            self.collage.attach(self.background_img)
            self.currently_showing = 'Background only'
        else:
            self.show_original()

    def _compute_background(self):
        """Computes the background image for the current image.

        The result is cached in ``self.background_img`` until the next unit
        is loaded; the mask used is kept in ``self.foreground_mask``.
        """

        if not hasattr(self, 'background_img'):
            # need to scale the mask, as Collage class does NOT automatically rescale
            self.foreground_mask = mask_image(self.current_img, out_dtype=bool)
            temp_background_img = np.copy(self.current_img)
            temp_background_img[self.foreground_mask] = 0.0
            self.background_img = scale_0to1(temp_background_img,
                                             exclude_outliers_below=1,
                                             exclude_outliers_above=1)

    def show_tails_trimmed(self, no_toggle=False):
        """Show the image rescaled with outlier exclusion at both tails.

        If this view is already showing and ``no_toggle`` is false, switches
        back to the original image instead. The rescaled image is computed
        once per unit and cached in ``self.tails_trimmed_img``.
        """

        if self.currently_showing != 'tails_trimmed' or no_toggle:
            if not hasattr(self, 'tails_trimmed_img'):
                self.tails_trimmed_img = scale_0to1(self.current_img,
                                                    exclude_outliers_below=1,
                                                    exclude_outliers_above=0.05)
            self.collage.attach(self.tails_trimmed_img)
            self.currently_showing = 'tails_trimmed'
        else:
            self.show_original()

    def show_original(self):
        """Show the original"""

        self.collage.attach(self.current_img)
        self.currently_showing = 'original'

    def cleanup(self):
        """Save the ratings, disconnect the callbacks and close all figures."""

        # save ratings before exiting
        self.save_ratings()

        self.fig.canvas.mpl_disconnect(self.con_id_click)
        self.fig.canvas.mpl_disconnect(self.con_id_keybd)
        plt.close('all')
コード例 #5
0
def checkerboard(
    img_spec1=None,
    img_spec2=None,
    patch_size=10,
    view_set=(0, 1, 2),
    num_slices=(10, ),
    num_rows=2,
    rescale_method='global',
    background_threshold=0.05,
    annot=None,
    padding=5,
    output_path=None,
    figsize=None,
):
    """
    Compare two images by interleaving them in a checkerboard pattern.

    Both images are preprocessed together, mixed slice-by-slice with the
    checker mixer, laid out in a collage, and the collage is saved.

    Parameters
    ----------
    img_spec1 : str or nibabel image-like object
        First MR image (or path to one) to be visualized.

    img_spec2 : str or nibabel image-like object
        Second MR image (or path to one) to be visualized.

    patch_size : int or list or (int, int) or None
        Size of a checker patch (square or rectangular).
        If None, the number of voxels per patch is chosen such that
        there will be 7 patches through the width/height.

    view_set : iterable
        Integers specifying the dimensions to be visualized.
        Choices: one or more of (0, 1, 2) for a 3D image.

    num_slices : int or iterable of size as view_set
        Number of slices to be selected for each view.
        Must be of the same length as view_set, each element giving the
        number of slices for that dimension. If only one number is given,
        the same number is used for all dimensions.

    num_rows : int
        Number of rows (top to bottom) per each of 3 dimensions.

    rescale_method : bool or str or list or None
        Range to rescale the intensity values to.
        Default: 'global', min and max computed from both images.
        If false or None, no rescaling is done (does not work yet).

    background_threshold : float or str
        Threshold below which background voxels are set to zero.
        Default: 0.05. May also be a percentile string: '5%', '10%'.
        Specify None to disable thresholding.

    annot : str
        Text to annotate the visualization with.

    padding : int
        Number of voxels to pad around each panel.

    output_path : str
        Path to save the generated collage to.

    figsize : list
        Figure size in inches, e.g. [12, 12] or [20, 20].

    Returns
    -------
    collage : Collage
        The collage object holding the generated visualization.

    """

    # fixed intensity window: the display always maps [0, 1] to gray levels
    panel_style = {
        'interpolation': 'none',
        'aspect': 'auto',
        'origin': 'lower',
        'cmap': 'gray',
        'vmin': 0.0,
        'vmax': 1.0,
    }

    first_img, second_img = _preprocess_images(
        img_spec1,
        img_spec2,
        rescale_method=rescale_method,
        bkground_thresh=background_threshold,
        padding=padding)

    board = Collage(view_set=view_set,
                    num_slices=num_slices,
                    num_rows=num_rows,
                    figsize=figsize,
                    display_params=panel_style)

    # bind the patch size so the mixer can be applied to each pair of slices
    board.transform_and_attach(
        (first_img, second_img),
        func=partial(_checker_mixer, checker_size=patch_size))
    board.save(output_path=output_path, annot=annot)

    return board