예제 #1
0
def plot_roi_tracking(roi_mask_df, figpar, title=None):
    """
    plot_roi_tracking(roi_mask_df, figpar)
    
    Plots ROI tracking examples, for different session permutations, and union 
    across permutations.

    One row of subplots is used: the "fewest" and "most" tracked ROI 
    permutations, a text-only subplot (which also hosts the session legend), 
    and the "union" subplot. An additional axis holds a vertical scale marker. 
    A summary of the match and conflict counts is logged.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with a single row, and the following columns, in 
            addition to the basic sess_df columns: 
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)
            - "union_n_conflicts" (int): number of conflicts after union
            for "union", "fewest" and "most" tracked ROIs:
            - "{}_registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val),
                ordered by {}_sess_ns if "fewest" or "most"
            - "{}_n_tracked" (int): number of tracked ROIs
            for "fewest", "most" tracked ROIs:
            - "{}_sess_ns" (list): ordered session number 
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            (not modified by this function)
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - ValueError: if roi_mask_df does not contain exactly one row
    """

    if len(roi_mask_df) != 1:
        raise ValueError("Expected only one row in roi_mask_df")
    roi_mask_row = roi_mask_df.loc[roi_mask_df.index[0]]

    # the empty string marks the text-only subplot between "most" and "union"
    columns = ["fewest", "most", "", "union"]

    # Work on copies of the two dictionary levels modified below, so that the 
    # figure parameters passed in by the caller are not altered.
    figpar = dict(figpar)
    figpar["init"] = dict(figpar["init"])

    figpar["init"]["ncols"] = len(columns)
    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 5.05
    figpar["init"]["subplot_wid"] = 5.05
    figpar["init"]["gs"] = {"wspace": 0.06}

    # MUST ADJUST if anything above changes [left, bottom, width, height]
    new_axis_coords = [0.905, 0.125, 0.06, 0.74] 

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    # additional axis, used only to hold the scale marker
    sub_ax_scale = fig.add_axes(new_axis_coords)
    plot_util.remove_axis_marks(sub_ax_scale)
    sub_ax_scale.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=1.05, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6

    # label the first subplot with the line/plane name, and start the log text
    lp_col = plot_helper_fcts.get_line_plane_idxs(
        roi_mask_row["lines"], roi_mask_row["planes"]
        )[2]
    lp_name = plot_helper_fcts.get_line_plane_name(
        roi_mask_row["lines"], roi_mask_row["planes"]
        )
    ax[0, 0].set_ylabel(lp_name, fontweight="bold", color=lp_col)
    log_info = f"Conflicts and matches for a {lp_name} example:"

    for c, column in enumerate(columns):
        sub_ax = ax[0, c]

        if column == "":
            # text-only subplot: no masks are drawn here
            sub_ax.set_axis_off()
            subplot_title = \
                "     Union - conflicts\n...   ====================>"
            sub_ax.set_title(subplot_title, fontweight="bold", y=0.5)
            continue

        plot_util.remove_axis_marks(sub_ax)
        for spine in ["right", "left", "top", "bottom"]:
            sub_ax.spines[spine].set_visible(True)

        if column in ["fewest", "most"]:
            y = 1.01
            # masks for these columns are ordered by the permutation's 
            # session order
            ord_sess_ns = roi_mask_row[f"{column}_sess_ns"]
            ord_sess_ns_str = ", ".join([str(n) for n in ord_sess_ns])

            n_matches = int(roi_mask_row[f"{column}_n_tracked"])
            subplot_title = f"{n_matches} matches\n(sess {ord_sess_ns_str})"
            log_info = (f"{log_info}\n{TAB}"
                f"{column.capitalize()} matches (sess {ord_sess_ns_str}): "
                f"{n_matches}")
        
        elif column == "union":
            y = 1.04
            ord_sess_ns = roi_mask_row["sess_ns"]
            n_union = int(roi_mask_row[f"{column}_n_tracked"])
            n_conflicts = int(roi_mask_row[f"{column}_n_conflicts"])
            n_matches = n_union - n_conflicts

            subplot_title = f"{n_matches} matches"
            log_info = (f"{log_info}\n{TAB}"
                "Union - conflicts: "
                f"{n_union} - {n_conflicts} = {n_matches} matches"
                )

        sub_ax.set_title(subplot_title, fontweight="bold", y=y)

        roi_masks = create_sess_roi_masks(
            roi_mask_row, 
            mask_key=f"{column}_registered_roi_mask_idxs"
            )
        
        # sessions are always drawn in "sess_ns" order, but each session's 
        # mask is retrieved from its position in the column's own ordering
        for sess_n in roi_mask_row["sess_ns"]:
            col = sess_cols[int(sess_n)]
            s = ord_sess_ns.index(sess_n)
            add_roi_mask(sub_ax, roi_masks[s], col=col, alpha=alpha)

    # add scale marker
    hei_len = roi_mask_row["roi_mask_shapes"][1]
    add_scale_marker(
        sub_ax_scale, side_len=hei_len, ori="vertical", quadrant=3, fontsize=20
        )

    logger.info(log_info, extra={"spacing": "\n"})

    # add legend
    add_sess_col_leg(
        ax[0, columns.index("")], 
        sess_cols, 
        bbox_to_anchor=(0.67, 0.3), 
        alpha=alpha
        )

    return ax
예제 #2
0
def plot_roi_masks_overlayed(roi_mask_df, figpar, title=None):
    """
    plot_roi_masks_overlayed(roi_mask_df, figpar)

    Plots, for each line/plane, the ROI masks of all sessions overlayed in one 
    subplot. Masks are cropped if roi_mask_df has a "crop_fact" column.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)
            
            and optionally, if cropping:
            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally [0, 1]

        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane has more than one row
        - NotImplementedError: if the mask image heights are not the same for 
            all line/planes
    """

    crop = "crop_fact" in roi_mask_df.columns

    figpar = sess_plot_util.fig_init_linpla(figpar)
    figpar["init"].update({
        "sharex"     : False,
        "sharey"     : False,
        "subplot_hei": 5.2,
        "subplot_wid": 5.2,
        "gs"         : {"wspace": 0.03, "hspace": 0.32},
        })

    # Scale marker axis position, as [left, bottom, width, height]. 
    # MUST ADJUST if any of the figure parameters above change. 
    # If cropping, the axis is moved to the left.
    scale_left = 0.04 if crop else 0.885
    scale_ax_coords = [scale_left, 0.11, 0.1, 0.33]

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    scale_ax = fig.add_axes(scale_ax_coords)
    plot_util.remove_axis_marks(scale_ax)
    scale_ax.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=0.95, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6

    # mask image heights, collected across line/planes for the scale marker
    mask_heis = []
    for (line, plane), grp_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(grp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        grp_row = grp_df.iloc[0]
        
        sess_masks = create_sess_roi_masks(grp_row, crop=crop)
        mask_heis.append(sess_masks.shape[1])

        # overlay each session's mask, in the session's colour
        for m, sess_n in enumerate(grp_row["sess_ns"]):
            add_roi_mask(
                sub_ax, sess_masks[m], col=sess_cols[int(sess_n)], alpha=alpha
                )

    # add legend
    add_sess_col_leg(
        ax[0, 1], sess_cols, bbox_to_anchor=(0.7, -0.01), alpha=alpha
        )

    # add scale marker (a single marker is shared by all subplots)
    uniq_heis = np.unique(mask_heis)
    if len(uniq_heis) != 1:
        raise NotImplementedError(
            "Adding scale bar not implemented if ROI mask image heights are "
            "different for different planes."
            )

    if crop:
        quadrant = 1
    else:
        quadrant = 3

    add_scale_marker(
        scale_ax, side_len=uniq_heis[0], ori="vertical", quadrant=quadrant, 
        fontsize=22
        )
 
    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, ylab="", kind="map")

    return ax
예제 #3
0
def plot_imaging_planes(imaging_plane_df, figpar, title=None):
    """
    plot_imaging_planes(imaging_plane_df, figpar)
    
    Plots the maximum projection image of the imaging plane for each 
    line/plane, with a shared vertical scale marker.

    Required args:
        - imaging_plane_df (pd.DataFrame in dict format):
            dataframe with one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane has more than one row
        - NotImplementedError: if the image heights are not the same for all 
            line/planes
    """

    figpar = sess_plot_util.fig_init_linpla(figpar)
    figpar["init"].update({
        "sharex"     : False,
        "sharey"     : False,
        "subplot_hei": 2.4,
        "subplot_wid": 2.4,
        "gs"         : {"wspace": 0.25, "hspace": 0.2},
        })

    # Scale marker axis position, as [left, bottom, width, height]. 
    # MUST ADJUST if any of the figure parameters above change.
    scale_ax_coords = [0.91, 0.115, 0.15, 0.34]

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    scale_ax = fig.add_axes(scale_ax_coords)
    plot_util.remove_axis_marks(scale_ax)
    scale_ax.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=1, weight="bold")

    # artists with a zorder below this threshold get rasterized
    raster_zorder = -12

    # image heights, collected across line/planes for the scale marker
    image_heis = []
    grouped = imaging_plane_df.groupby(["lines", "planes"])
    for (line, plane), grp_df in grouped:
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(grp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        grp_row = grp_df.iloc[0]
        
        # add projection, placed below the rasterization threshold
        projection = np.asarray(grp_row["max_projections"])
        image_heis.append(projection.shape[0])
        add_imaging_plane(
            sub_ax, projection, alpha=0.98, zorder=raster_zorder - 1
            )
        
    # add scale marker (a single marker is shared by all subplots)
    uniq_heis = np.unique(image_heis)
    if len(uniq_heis) != 1:
        raise NotImplementedError(
            "Adding scale bar not implemented if ROI mask image heights are "
            "different for different planes."
            )
    add_scale_marker(
        scale_ax, side_len=uniq_heis[0], ori="vertical", quadrant=3, 
        fontsize=16
        )
 
    logger.info("Rasterizing imaging plane images...", extra={"spacing": TAB})
    for sub_ax in ax.ravel():
        sub_ax.set_rasterization_zorder(raster_zorder)

    # Add plane, line info to plots, then strip the axis marks
    sess_plot_util.format_linpla_subaxes(ax, ylab="", kind="map")
    for sub_ax in ax.ravel():
        plot_util.remove_axis_marks(sub_ax)

    return ax
예제 #4
0
def plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar, title=None):
    """
    plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar)

    Plots ROI mask contours overlayed over imaging planes, and ROI masks 
    overlayed over each other across sessions.

    Each line/plane is given a group of subplots spanning 2 rows. A 
    horizontal scale marker is added to the last subplot of the group's 
    second row, and a session legend to the last subplot of the figure.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 

            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_idxs" (list): list of mask indices for each session, 
                and each ROI (sess x (ROI, hei, wid) x val) (not registered)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)

            and optionally, if cropping:
            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally [0, 1]

        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if there are fewer than 2 sessions (scale bar 
            and legend placement rely on distinct first and last columns)
        - RuntimeError: if a line/plane has more than one row
    """

    n_lines = len(roi_mask_df["lines"].unique())
    n_planes = len(roi_mask_df["planes"].unique())

    sess_cols = get_sess_cols(roi_mask_df)
    n_sess = len(sess_cols)
    n_cols = n_sess * n_lines

    # Checked up front, before any figure is created: the scale bar and 
    # legend are both placed in the last column, which must differ from the 
    # first one.
    if n_sess < 2:
        raise NotImplementedError(
            "Scale bar placement not implemented for fewer than 2 "
            "sessions."
            )

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 2.3
    figpar["init"]["subplot_wid"] = 2.3
    figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.2}
    figpar["init"]["ncols"] = n_cols

    # 2 rows of subplots per plane
    fig, ax = plot_util.init_fig(n_cols * n_planes * 2, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.93, weight="bold")

    crop = "crop_fact" in roi_mask_df.columns

    alpha = 0.6
    # artists with a zorder below this threshold get rasterized
    raster_zorder = -12

    for (line, plane), lp_mask_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, lp_col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]

        # identify subplots
        # NOTE(review): the offsets are scaled by n_planes and n_lines, 
        # although each group spans 2 rows and the grid has n_sess columns 
        # per line. This only matches when n_planes <= 2, and when 
        # n_lines == 1 or n_sess == n_lines. The column slice also extends 
        # 1 past n_sess. Left unchanged here - confirm against 
        # add_proj_and_roi_masks() before using with other layouts.
        base_row = (pl % n_planes) * n_planes
        base_col = (li % n_lines) * n_lines

        ax_grp = ax[base_row : base_row + 2, base_col : base_col + n_sess + 1]

        # add imaging planes and masks
        imaging_planes = add_proj_and_roi_masks(
            ax_grp, lp_row, sess_cols, crop=crop, alpha=alpha, 
            proj_zorder=raster_zorder - 1
            )

        # add markings (second row of the group, middle session column)
        shared_row = base_row + 1
        shared_col = base_col + int((n_sess - 1) // 2)
        shared_sub_ax = ax[shared_row, shared_col]

        if shared_col == 0:
            shared_sub_ax.set_ylabel(lp_name, fontweight="bold", color=lp_col)
        else:
            # write the line/plane name inside the first subplot of the row 
            lp_sub_ax = ax[shared_row, 0]
            lp_sub_ax.set_xlim([0, 1])
            lp_sub_ax.set_ylim([0, 1])
            lp_sub_ax.text(
                0.5, 0.5, lp_name, fontweight="bold", color=lp_col, 
                ha="center", va="center", fontsize="x-large"
                )

        # add scale bar, in the last subplot of the row
        scale_ax = ax[shared_row, -1]
        wid_len = imaging_planes[0].shape[-1]
        add_scale_marker(
            scale_ax, side_len=wid_len, ori="horizontal", quadrant=1, 
            fontsize=20
            )

    logger.info("Rasterizing imaging plane images...", extra={"spacing": TAB})
    for i, ax_row in enumerate(ax):
        for sub_ax in ax_row:
            plot_util.remove_axis_marks(sub_ax)
            # only even rows are rasterized (presumably the rows holding the 
            # imaging plane projections)
            if i % 2 == 0:
                sub_ax.set_rasterization_zorder(raster_zorder)

    # add legend, in the last subplot of the figure
    add_sess_col_leg(
        ax[-1, -1], sess_cols, bbox_to_anchor=(1, 0.6), alpha=alpha, 
        fontsize="small"
        )

    return ax
예제 #5
0
def add_scale_marker(sub_ax, side_len=512, ori="horizontal", quadrant=1, 
                     fontsize=20):
    """
    add_scale_marker(sub_ax)

    Adds a scale marker and length in um to the subplot.

    The subplot axis along which the marker lies is rescaled to 
    [0, side length in um], and the other axis to [0, 1]. Axis marks are 
    removed from the subplot.

    The marker length is the largest value that fits within half the side 
    length, taken from 25, 50, 100, ... um if half the side length is at least 
    25 um, and from the powers of 2 (..., 0.5, 1, 2, 4, 8, 16 um) otherwise.

    Required args:
        - sub_ax (plt Axis subplot): 
            subplot

    Optional args:
        - side_len (int):
            length in pixels of the subplot side 
            (x axis if ori is "horizontal", and y axis if ori is "vertical")
            default: 512
        - ori (str):
            scale marker orientation ("horizontal" or "vertical")
            default: "horizontal"
        - quadrant (int):
            subplot quadrant in the corner of which to plot scale marker
            (1: top right, 2: top left, 3: bottom left, 4: bottom right)
            default: 1
        - fontsize (int):
            font size for scale length text
            default: 20
    """

    side_len_um = side_len * UM_PER_PIX
    half_len_um = side_len_um / 2
    
    # The exponent is floored in both cases, so that the marker length is 
    # always a round value (25 * 2^i or 2^i). np.log2() is used instead of a 
    # ratio of natural logs, to avoid rounding errors at exact powers of 2.
    if half_len_um >= 25:
        i = np.floor(np.log2(half_len_um / 25))
        bar_len_um = int(25 * (2 ** i))
    else:
        i = np.floor(np.log2(half_len_um))
        bar_len_um = 2 ** i
        # lengths of 2 um or more are whole numbers, shorter ones stay floats
        if i >= 1:
            bar_len_um = int(bar_len_um)

    line_kwargs = {
        "color"         : "black",
        "lw"            : 4,
        "solid_capstyle": "butt",
    }

    if quadrant not in [1, 2, 3, 4]:
        gen_util.accepted_values_error("quadrant", quadrant, [1, 2, 3, 4])
    
    text_va = "center"
    if ori == "horizontal":
        axis_width_pts = sub_ax.get_window_extent().width

        sub_ax.set_xlim([0, side_len_um])
        sub_ax.set_ylim([0, 1])

        if quadrant in [1, 2]: # top
            y_coord = 0.95
            text_y = 0.8
        else:
            y_coord = 0.05
            text_y = 0.2
        
        # The bar is shifted inward by the spine linewidth, expressed in data 
        # units, presumably to prevent it from overlapping with the spine.
        # NOTE(review): linewidth is in points, whereas the window extent is 
        # in display units - confirm the two are meant to be compared as is.
        if quadrant in [1, 4]: # right
            spine_width = sub_ax.spines["right"].get_linewidth()
            adj_um = spine_width / axis_width_pts * side_len_um
            xs = [side_len_um - bar_len_um - adj_um, side_len_um - adj_um]
            text_x = xs[-1]
            text_ha = "right"
        else:
            spine_width = sub_ax.spines["left"].get_linewidth()
            adj_um = spine_width / axis_width_pts * side_len_um
            xs = [adj_um, bar_len_um + adj_um]
            text_x = xs[0]
            text_ha = "left"

        sub_ax.plot(xs, [y_coord, y_coord], **line_kwargs)

    elif ori == "vertical":
        axis_height_pts = sub_ax.get_window_extent().height

        sub_ax.set_ylim([0, side_len_um])
        sub_ax.set_xlim([0, 1])

        # same inward shift as for the horizontal case, along the y axis
        if quadrant in [1, 2]: # top
            spine_height = sub_ax.spines["top"].get_linewidth()
            adj_um = spine_height / axis_height_pts * side_len_um
            ys = [side_len_um - bar_len_um - adj_um, side_len_um - adj_um]
            text_y = ys[-1]
            text_va = "top"
        else:
            spine_height = sub_ax.spines["bottom"].get_linewidth()
            adj_um = spine_height / axis_height_pts * side_len_um
            ys = [adj_um, bar_len_um + adj_um]
            text_y = ys[0]
            text_va = "bottom"
        
        if quadrant in [1, 4]: # right
            x_coord = 0.95
            text_x = 0.85
            text_ha = "right"
        else:
            x_coord = 0.05
            text_x = 0.1
            text_ha = "left"

        sub_ax.plot([x_coord, x_coord], ys, **line_kwargs)

    else:
        gen_util.accepted_values_error("ori", ori, ["horizontal", "vertical"])

    # marker length label, in micrometers
    mu = u"\u03BC"
    sub_ax.text(
        text_x, text_y, r"{} {}m".format(bar_len_um, mu), 
        ha=text_ha, va=text_va, fontsize=fontsize, fontweight="bold",
        )

    plot_util.remove_axis_marks(sub_ax)

    return