示例#1
0
def add_USI_boxes(ax, chosen_rois_df, sorted_target_idxs): 
    """
    add_USI_boxes(ax, chosen_rois_df, sorted_target_idxs)

    Adds boxes with USI values to individual plots (e.g., trace plots).

    Each line/plane group is plotted in column `li` (line index), in a stack 
    of len(sorted_target_idxs) rows offset by the plane index. Within that 
    stack, row r receives the ROI whose "target_idxs" value equals 
    sorted_target_idxs[r]. Target values missing from a line/plane group are 
    skipped.

    Required args:
        - ax (subplot array):
            pyplot axis array
        - chosen_rois_df (pd.DataFrame):
            chosen ROIs dataframe with, in addition to the basic sess_df 
            columns, "target_idxs", and "roi_idxs" (the value written in the 
            box, formatted as "USI = {:.2f}").
        - sorted_target_idxs (list): 
            order in which different target_idxs should be added to subplots

    Raises:
        - RuntimeError: 
            if a target_idxs value appears more than once within a 
            line/plane group
    """
    props = dict(
        boxstyle="round", facecolor="white", edgecolor="black", alpha=0.5, 
        lw=1.5)

    n_targets = len(sorted_target_idxs)
    for (line, plane), lp_df in chosen_rois_df.groupby(["lines", "planes"]):
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for r, target_idx in enumerate(sorted_target_idxs):
            rows = lp_df.loc[lp_df["target_idxs"] == target_idx]
            if len(rows) == 0:
                continue
            if len(rows) > 1:
                raise RuntimeError(
                    "Expected target_idxs instances to be unique per "
                    "line/plane."
                    )

            row = rows.iloc[0]
            sub_ax = ax[r + pl * n_targets, li]

            # place a text box near the top left, in axes coordinates
            sub_ax.text(0.1, 0.9, f"USI = {row['roi_idxs']:.2f}", 
                transform=sub_ax.transAxes, fontsize=15, va="center", 
                bbox=props)
示例#2
0
def plot_idx_correlations(idx_corr_df,
                          permpar,
                          figpar,
                          permute="sess",
                          corr_type="corr",
                          title=None,
                          small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    One column of subplots is used per session pair (the number of columns 
    is padded up to an even number), with one row per plane, and lines 
    placed along the x axis of each subplot. For each line/plane, the 
    adjusted null CI is plotted along with the observed correlation +/- its 
    bootstrapped standard deviation, and a significance marker where 
    applicable. Per line/plane p-values are also logged.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI 
                index correlations
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all"). For 
            "sess" and "all", "diff_" is prepended to corr_type, and if the 
            result is "diff_corr", normalized correlations ("_norm" columns) 
            are plotted.
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title (if normalized correlations are plotted, 
            "Correlations" is replaced with "Normalized correlations")
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array): 
            array of subplots
    """

    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # title is optional: only adjust it if one was provided
            if title is not None:
                title = title.replace(
                    "Correlations", "Normalized correlations"
                    )

    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2)  # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(figpar,
                                            kind="reg",
                                            n_sub=int(n_pairs / 2))

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"

    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    comp_info = misc_analys.get_comp_info(permpar)
    # depends only on permpar, so computed once for all subplots
    sensitivity = misc_analys.get_sensitivity(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        # x tick label for this line, e.g. "L23-..." -> "L2/3"
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI: [low, median, high]
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(sub_ax,
                              extr,
                              med=CI[1],
                              x=li,
                              width=0.45,
                              med_rat=0.025)

            y = row[f"{col_base}{norm_str}_corrs"]

            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(sub_ax,
                               x=li,
                               y=y,
                               err=err,
                               color=col,
                               capsize=8)

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            # side of the null median on which the observed value falls
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(p_val,
                                                 sensitivity=sensitivity,
                                                 side=side,
                                                 tails=permpar["tails"],
                                                 p_thresh=permpar["p_val"])

            if len(sig_str):
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(sub_ax,
                                          li,
                                          high,
                                          rel_y=0.1,
                                          color=col,
                                          fontsize=24,
                                          mark=sig_str)

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax,
                                         datatype="roi",
                                         lines=["", ""],
                                         xticks=[0, 1],
                                         ylab="",
                                         xlab="",
                                         kind="traces")

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            if row_n == 0:
                # padding columns (beyond the last session pair) get no title
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    sub_ax.set_title(sess_pair_title,
                                     fontweight="bold",
                                     y=1.07)
                sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(ax[row_n],
                                   3,
                                   axis="y",
                                   weight="bold",
                                   share=False,
                                   update_ticks=True)

    return ax
示例#3
0
def get_idx_corr_ylims(idx_corr_df, norm=False):
    """
    get_idx_corr_ylims(idx_corr_df)

    Returns data edges (min and max data to be plotted), per plane.

    For each plane, the lowest and highest values are taken across all 
    session pairs from the null CI edges and from the correlations 
    -/+ their standard deviations. Each range is then padded by 10% of its 
    width on both sides.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index 
                correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for normalized 
                intersession ROI index correlations

    Optional args:
        - norm (bool):
            if True, the normalized correlation columns ("_norm" suffix) are 
            used, and session pairs are identified accordingly
            default: False

    Returns:
        - plane_pts (list):
            [low_pt, high_pt] for each plane, in plane order, based on plane 
            indices
    """

    norm_str = "_norm" if norm else ""
    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    col_bases = [f"{pair[0]}v{pair[1]}" for pair in sess_pairs]

    # (plane index, [low, high]) for each plane, in groupby order
    indexed_pts = []
    for plane, plane_df in idx_corr_df.groupby("planes"):
        _, plane_idx, _, _ = plot_helper_fcts.get_line_plane_idxs(plane=plane)

        lows, highs = [], []
        for base in col_bases:
            null_CIs = plane_df[f"{base}_null_CIs"]
            corrs = plane_df[f"{base}{norm_str}_corrs"]
            stds = plane_df[f"{base}{norm_str}_corr_stds"]

            # null CI edges (CI[0]: low, CI[2]: high)
            lows.append(np.min([CI[0] for CI in null_CIs]))
            highs.append(np.max([CI[2] for CI in null_CIs]))

            # data edges
            lows.append((corrs - stds).min())
            highs.append((corrs + stds).max())

        low, high = np.min(lows), np.max(highs)
        margin = (high - low) / 10
        indexed_pts.append((plane_idx, [low - margin, high + margin]))

    plane_order = np.argsort([plane_idx for plane_idx, _ in indexed_pts])
    return [indexed_pts[i][1] for i in plane_order]
示例#4
0
def plot_rand_corr_ex_data(idx_corr_norm_df, title=None):
    """
    plot_rand_corr_ex_data(idx_corr_norm_df)

    Plots example random correlation data in a scatterplot and histogram to 
    show how normalized residual correlations are calculated.

    Required args:
        - idx_corr_norm_df (pd.DataFrame):
            dataframe with one row for a line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_norm_corrs (float): normalized intersession ROI index 
                correlations
            - {}v{}_rand_ex_corrs (float): unnormalized intersession 
                ROI index correlations for an example of randomized data
            - {}v{}_rand_corr_meds (float): median of randomized correlations

            - {}v{}_corr_data (list): intersession values to correlate
            - {}v{}_rand_ex (list): intersession values for an example of 
                randomized data
            - {}v{}_rand_corrs_binned (list): binned random unnormalized 
                intersession ROI index correlations
            - {}v{}_rand_corrs_bin_edges (list): bins edges

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - ValueError: 
            if idx_corr_norm_df does not contain exactly one row
        - RuntimeError: 
            if not exactly one session pair is found in idx_corr_norm_df
    """

    # validate the input before creating the figure, so that no figure is 
    # left open if an error is raised
    if len(idx_corr_norm_df) != 1:
        raise ValueError("Expected idx_corr_norm_df to contain only one row.")

    sorted_pairs = get_sorted_sess_pairs(idx_corr_norm_df, norm=True)

    if len(sorted_pairs) != 1:
        raise RuntimeError(
            "Expected to find only one pair of sessions for which to plot data."
        )
    sess_pair = sorted_pairs[0]

    plot_types = ["scatter", "hist"]
    fig, ax = plt.subplots(nrows=len(plot_types),
                           figsize=[8.7, 9.3],
                           gridspec_kw={"hspace": 0.7})

    row = idx_corr_norm_df.loc[idx_corr_norm_df.index[0]]
    corr_name = f"{sess_pair[0]}v{sess_pair[1]}"
    _, _, col, _ = plot_helper_fcts.get_line_plane_idxs(
        row["lines"], row["planes"])

    if title is not None:
        fig.suptitle(title, y=0.95, weight="bold")

    # plot scatterplot
    scatt_ax = ax[0]
    plot_corr_ex_data_scatterplot(scatt_ax, row, corr_name=corr_name, col=col)

    # plot histogram
    hist_ax = ax[1]
    plot_corr_ex_data_histogram(hist_ax, row, corr_name=corr_name, col=col)

    plot_util.set_interm_ticks(ax,
                               n_ticks=4,
                               axis="y",
                               share=False,
                               fontweight="bold")

    return ax
示例#5
0
def plot_imaging_planes(imaging_plane_df, figpar, title=None):
    """
    plot_imaging_planes(imaging_plane_df, figpar)

    Plots the maximum projection image of each imaging plane, one subplot 
    per line/plane. A vertical scale marker is added in an extra axis near 
    the right edge of the figure.

    Required args:
        - imaging_plane_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse (exactly one per line/plane), 
            and the following columns, in addition to the basic sess_df 
            columns: 
            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: 
            if there is not exactly one row per line/plane
        - NotImplementedError: 
            if image heights differ across line/planes, since the scale 
            marker is sized from a single image height
    """

    figpar = sess_plot_util.fig_init_linpla(figpar)

    init_par = figpar["init"]
    init_par.update({
        "sharex": False,
        "sharey": False,
        "subplot_hei": 2.4,
        "subplot_wid": 2.4,
        "gs": {"wspace": 0.25, "hspace": 0.2},
    })

    # scale marker axis position, in figure coordinates: 
    # [left, bottom, width, height]
    # NOTE: MUST be adjusted if any of the figure parameters above change
    scale_ax_rect = [0.91, 0.115, 0.15, 0.34]

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_par)

    scale_ax = fig.add_axes(scale_ax_rect)
    plot_util.remove_axis_marks(scale_ax)
    scale_ax.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=1, weight="bold")

    # images are drawn below this zorder, so that they are rasterized
    raster_zorder = -12
    plane_heights = []
    lp_groups = imaging_plane_df.groupby(["lines", "planes"])
    for (line, plane), lp_df in lp_groups:
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        lp_ax = ax[pl, li]

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")

        max_proj = np.asarray(lp_df.loc[lp_df.index[0], "max_projections"])
        plane_heights.append(max_proj.shape[0])
        add_imaging_plane(
            lp_ax, max_proj, alpha=0.98, zorder=raster_zorder - 1
            )

    # the scale marker length is set from the (single) image height
    unique_heights = np.unique(plane_heights)
    if len(unique_heights) != 1:
        raise NotImplementedError(
            "Adding scale bar not implemented if ROI mask image heights are "
            "different for different planes."
            )
    add_scale_marker(
        scale_ax, side_len=unique_heights[0], ori="vertical", quadrant=3, 
        fontsize=16
        )

    logger.info("Rasterizing imaging plane images...", extra={"spacing": TAB})
    flat_axes = ax.reshape(-1)
    for lp_ax in flat_axes:
        lp_ax.set_rasterization_zorder(raster_zorder)

    # add plane, line info to plots, then strip the axis marks
    sess_plot_util.format_linpla_subaxes(ax, ylab="", kind="map")
    for lp_ax in flat_axes:
        plot_util.remove_axis_marks(lp_ax)

    return ax
def plot_stim_data_df(stim_data_df, stimpar, permpar, figpar, pop_stats=True, 
                      title=None):
    """
    plot_stim_data_df(stim_data_df, stimpar, permpar, figpar)

    Plots stimulus comparison data.

    For each line/plane, one bar per stimulus type is plotted in the 
    corresponding subplot. The row(s) with "all" as line or plane are plotted 
    in an additional subplot placed to the right of the main subplots. The 
    individual line/plane values are overlaid as dots on that subplot.

    Required args:
        - stim_data_df (pd.DataFrame):
            dataframe with one row per line/plane and one for all line/planes 
            together (line and/or plane set to "all"), and the basic sess_df 
            columns, in addition to, for each stimtype:
            - stimtype (list): absolute fractional change statistics (me, err)
            - raw_p_vals (float): uncorrected p-value for data differences 
                between stimulus types 
            - p_vals (float): p-value for data differences between stimulus 
                types, corrected for multiple comparisons and tails
        - stimpar (dict): 
            dictionary with keys of StimPar namedtuple ("stimtype" is expected 
            to list the 2 stimulus types being compared)
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - pop_stats (bool):
            if True, analyses are run on population statistics, and not 
            individual tracked ROIs
            default: True
        - title (str):
            plot title
            default: None
    
    Returns:
        - ax (2D array): 
            array of subplots 
            (does not include added subplot for all line/plane data together)
    """

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    figpar["init"]["subplot_wid"] = 2.1
    figpar["init"]["subplot_hei"] = 4.2
    figpar["init"]["gs"] = {"hspace": 0.20, "wspace": 0.3}
    figpar["init"]["sharey"] = "row"
    
    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # additional subplot for all line/planes together, outside the main grid
    sub_ax_all = fig.add_axes([1.05, 0.11, 0.3, 0.77])

    # shallow copy, so that stimpar itself is never modified
    stimtypes = list(stimpar["stimtype"])

    # indicate bootstrapped error with wider capsize
    capsize = 8 if pop_stats else 6

    lp_data = []
    cols = []
    for (line, plane), lp_df in stim_data_df.groupby(["lines", "planes"]):
        x = [0, 1]
        # rows: statistics (me, err...), columns: stimulus types
        data = np.vstack(
            [lp_df[stimtypes[0]].tolist(), lp_df[stimtypes[1]].tolist()]
            ).T
        y = data[0]
        err = data[1:]

        if line != "all" and plane != "all":
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            alpha = 0.5
            sub_ax = ax[pl, li]
            lp_data.append(y)
            cols.append(col)
        else:
            col = plot_helper_fcts.NEARBLACK
            dash = None
            alpha = 0.2
            sub_ax = sub_ax_all
            sub_ax.set_title("all", fontweight="bold")
    
        plot_util.plot_bars(
            sub_ax, x, y=y, err=err, width=0.5, lw=None, alpha=alpha, 
            color=col, ls=dash, capsize=capsize
            )

    # add dots to the all subplot
    # x offsets to spread dots out (allows for up to 4 line/planes)
    x_vals = np.asarray([-0.17, 0.25, -0.25, 0.17])
    lw = 4
    ms = 200
    for s, _ in enumerate(stimtypes):
        lp_stim_data = [data[s] for data in lp_data]
        sorter = np.argsort(lp_stim_data)
        for i, idx in enumerate(sorter):
            x_val = s + x_vals[i]
            # white behind
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, linewidth=lw, alpha=0.8, 
                color="white", zorder=10
                )
            
            # colored dots
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, alpha=0.6, linewidth=0, 
                color=cols[idx], zorder=11
                )
            
            # dot borders
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, color="None", 
                edgecolor=cols[idx], linewidth=lw, alpha=1, zorder=12
                )

    # add between stim significance 
    add_between_stim_sig(ax, sub_ax_all, stim_data_df, permpar)

    # add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", lines=None, 
        planes=["", ""], xticks=[0, 1], ylab="Absolute fractional change", 
        kind="reg", xlab=""
        )
    
    # adjust plot details, labelling x ticks with display names for the 
    # stimulus types
    stimtype_names = [
        "visual\nflow" if stimtype == "visflow" else stimtype
        for stimtype in stimtypes
        ]
    for sub_ax in fig.axes:
        y_max = sub_ax.get_ylim()[1]
        sub_ax.set_ylim([0, y_max])
        sub_ax.set_xticks([0, 1])
        sub_ax.set_xticklabels(
            stimtype_names, weight="bold", rotation=45, ha="right"
            )
        sub_ax.tick_params(axis="x", bottom=False)
    sub_ax_all.set_xlim(ax[0, 0].get_xlim())
        
    plot_util.set_interm_ticks(
        np.asarray(sub_ax_all), 4, axis="y", share=False, weight="bold"
        )

    return ax
示例#7
0
def plot_ex_gabor_traces(ex_traces_df, stimpar, figpar, title=None):
    """
    plot_ex_gabor_traces(ex_traces_df, stimpar, figpar)

    Plots example Gabor traces.

    Each line/plane gets a block of per_rows x per_cols subplots (near-square 
    divisors of the largest number of ROIs in any line/plane). ROIs are 
    placed in that block column by column.

    Required args:
        - ex_traces_df (pd.DataFrame):
            dataframe with a row for each ROI, and the following columns, 
            in addition to the basic sess_df columns: 
            - time_values (list): values for each frame, in seconds
            - roi_ns (list): selected ROI number
            - traces_sm (list): selected ROI sequence traces, smoothed, with 
                dims: seq x frames
            - trace_stat (list): selected ROI trace mean or median
        - stimpar (dict):
            dictionary with keys of StimPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
    
    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - ValueError: 
            if stimpar["stimtype"] is not "gabors"
    """

    if stimpar["stimtype"] != "gabors":
        raise ValueError("Expected stimpar['stimtype'] to be 'gabors'.")

    group_columns = ["lines", "planes"]
    n_per = np.max(
        [len(lp_df) for _, lp_df in ex_traces_df.groupby(group_columns)]
        )
    per_rows, per_cols = math_util.get_near_square_divisors(n_per)
    n_per = per_rows * per_cols  # number of subplots per line/plane

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=per_rows
        )
    figpar["init"]["subplot_hei"] = 1.36
    figpar["init"]["subplot_wid"] = 2.47
    figpar["init"]["ncols"] = per_cols * 2

    fig, ax = plot_util.init_fig(
        plot_helper_fcts.N_LINPLA * n_per, **figpar["init"]
        )
    if title is not None:
        fig.suptitle(title, y=1.03, weight="bold")

    # y limits returned for each subplot (NaN where nothing was plotted)
    ylims = np.full(ax.shape + (2, ), np.nan)
    
    logger.info("Plotting individual traces...", extra={"spacing": TAB})
    raster_zorder = -12
    for (line, plane), lp_df in ex_traces_df.groupby(group_columns):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for i, idx in enumerate(lp_df.index):
            row_idx = int(pl * per_rows + i % per_rows)
            col_idx = int(li * per_cols + i // per_rows)
            sub_ax = ax[row_idx, col_idx]

            ylims[row_idx, col_idx] = plot_ex_gabor_roi_traces(
                sub_ax, 
                lp_df.loc[idx],
                col=col,
                dash=dash,
                zorder=raster_zorder - 1
            )

        # NOTE(review): the time values of the last ROI of the last 
        # line/plane group are used for all subplots, which assumes they are 
        # shared across ROIs - confirm
        time_values = np.asarray(lp_df.loc[lp_df.index[-1], "time_values"])
        
    sess_plot_util.format_linpla_subaxes(ax, fluor="dff", 
        area=False, datatype="roi", sess_ns=None, xticks=None, kind="traces", 
        modif_share=False)

    # fix x ticks and lims (the same for all subplots)
    xlims = [time_values[0], time_values[-1]]
    xticks = np.linspace(*xlims, 6)
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xticks(xticks)
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold", skip=False)
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xlim(xlims)
    
    # reset y limits, for subplots in which traces were plotted
    for r in range(ax.shape[0]):
        for c in range(ax.shape[1]):
            if not np.isfinite(ylims[r, c].sum()):
                continue
            ax[r, c].set_ylim(ylims[r, c])

    plot_util.set_interm_ticks(
        ax, 2, axis="y", share=False, weight="bold", update_ticks=True
        )  

    # rasterize the gray lines
    logger.info("Rasterizing individual traces...", extra={"spacing": TAB})
    for sub_ax in ax.reshape(-1):
        sub_ax.set_rasterization_zorder(raster_zorder)

    return ax
示例#8
0
def add_between_sess_sig(ax, data_df, permpar, data_col="diff_stats", 
                         highest=None, ctrl=False, p_val_prefix=False, 
                         dry_run=False):
    """
    add_between_sess_sig(ax, data_df, permpar)

    Plot significance markers for significant comparisons between sessions.

    Two passes are run over the line/plane groups of data_df. On the first 
    pass, the p-values for each session pair are logged, and the y limits 
    of each line/plane subplot are raised to leave room for its significance 
    markers. On the second pass, the significance markers are plotted, 
    stacked one above the other above the highest data point of the 
    line/plane.

    Required args:
        - ax (plt Axis): 
            axis array, indexed as ax[plane index, line index]
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_col} (list): data stats (me, err)
            for session comparisons, e.g. 1v2:
            - p_vals_{}v{} (float): p-value for differences between sessions, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple

    Optional args:
        - data_col (str):
            data column name in data_df
            default: "diff_stats"
        - highest (list):
            highest point for each line/plane, in order (order of 
            data_df.groupby(["lines", "planes"])). If provided, this list is 
            updated in place.
            default: None
        - ctrl (bool): 
            if True, significance symbols should use control colour and symbol
            default: False
        - p_val_prefix (bool):
            if True, p-value columns start with data_col as a prefix 
            "{data_col}_p_vals_{}v{}".
            default: False
        - dry_run (bool):
            if True, a dry-run is done to get highest values, but nothing is 
            plotted or logged.
            default: False

    Returns:
        - highest (list):
            highest point for each line/plane, in order, after plotting 
            (data points are taken as me + last error value, and plotted 
            significance markers are included)

    Raises:
        - ValueError: 
            if highest is provided, but its length does not match the number 
            of line/plane groups
        - RuntimeError: 
            if a session does not appear exactly once for a line/plane
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    prefix = f"{data_col}_" if p_val_prefix else ""

    if not dry_run:
        logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    for pass_n in [0, 1]: # add significance markers on the second pass
        # highest is indexed in the order of these line/plane groups
        linpla_grps = list(data_df.groupby(["lines", "planes"]))
        if highest is None:
            highest = [0] * len(linpla_grps)
        elif len(highest) != len(linpla_grps):
            raise ValueError("If highest is provided, it must contain as "
                "many values as line/plane groups in data_df.")

        for l, ((line, plane), lp_df) in enumerate(linpla_grps):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            if ctrl:
                col = "gray"

            # check that each session appears exactly once for this line/plane
            lp_sess_ns = np.sort(lp_df["sess_ns"].unique())
            for sess_n in lp_sess_ns:
                rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(rows) != 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")    

            # collect the significant comparisons (every session pair, with 
            # sess_n1 < sess_n2)
            sig_p_vals, sig_strs, sig_xs = [], [], []
            lp_sig_str = f"{line_plane_name:6} (between sessions):"
            for i, sess_n1 in enumerate(lp_sess_ns):
                row_1s = lp_df.loc[lp_df["sess_ns"] == sess_n1]
                for sess_n2 in lp_sess_ns[i + 1: ]:
                    row_2s = lp_df.loc[lp_df["sess_ns"] == sess_n2]
                    if len(row_1s) != 1 or len(row_2s) != 1:
                        raise RuntimeError(
                            "Expected exactly one row per session."
                            )
                    row1 = row_1s.loc[row_1s.index[0]]
                    row2 = row_2s.loc[row_2s.index[0]]

                    # top of each session's data: me + last error value
                    row1_highest = row1[data_col][0] + row1[data_col][-1]
                    row2_highest = row2[data_col][0] + row2[data_col][-1]      
                    highest[l] = np.nanmax(
                        [highest[l], row1_highest, row2_highest]
                        )
                    
                    # a dry run only updates highest
                    if dry_run:
                        continue
                    
                    # the p-value is read from the first session's row
                    p_val = row1[
                        f"{prefix}p_vals_{int(sess_n1)}v{int(sess_n2)}"
                        ]
                    # direction of the difference: sign of (me 2 - me 1)
                    side = np.sign(row2[data_col][0] - row1[data_col][0])

                    sig_str = misc_analys.get_sig_symbol(
                        p_val, sensitivity=sensitivity, side=side, 
                        tails=permpar["tails"], p_thresh=permpar["p_val"], 
                        ctrl=ctrl
                        )

                    if len(sig_str):
                        sig_p_vals.append(p_val)
                        sig_strs.append(sig_str)
                        sig_xs.append([sess_n1, sess_n2])
                    
                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n1}v{sess_n2}: "
                        f"{p_val:.5f}{sig_str:3}"
                        )

            if dry_run:
                continue

            n = len(sig_p_vals)
            ylims = sub_ax.get_ylim()
            # vertical spacing between stacked markers: 1/8 of the current 
            # y range (on the second pass, this uses the y limits set on the 
            # first pass)
            prop = np.diff(ylims)[0] / 8.0

            if pass_n == 0:
                # first pass: log p-values and make room for the markers
                logger.info(lp_sig_str, extra={"spacing": TAB})
                if n == 0:
                    continue

                # count number of significant comparisons, and adjust y limits
                ylims = [
                    ylims[0], np.nanmax(
                        [ylims[1], highest[l] + prop * (n + 1)])
                        ]
                sub_ax.set_ylim(ylims)
            else:
                # second pass: plot the markers
                if n == 0:
                    continue

                if ctrl:
                    mark_rel_y = 0.22
                    fontsize = 14
                else:
                    mark_rel_y = 0.18
                    fontsize = 20

                # add significance markers sequentially, on second pass
                y_pos = highest[l]
                for s, (p_val, sig_str, sig_x) in enumerate(
                    zip(sig_p_vals, sig_strs, sig_xs)
                ):  
                    y_pos = highest[l] + (s + 1) * prop
                    plot_util.plot_barplot_signif(
                        sub_ax, sig_x, y_pos, rel_y=0.11, color=col, lw=3, 
                        mark_rel_y=mark_rel_y, mark=sig_str, fontsize=fontsize
                        )
                highest[l] = np.nanmax([highest[l], y_pos])

                # since prop may have grown since the first pass, extend the y 
                # limits if the top marker falls outside them
                if y_pos > ylims[1]:
                    sub_ax.set_ylim([ylims[0], y_pos * 1.1])

    return highest
示例#9
0
def plot_idxs(idx_df, sesspar, figpar, plot="items", density=True, n_bins=40, 
              title=None, size="reg"):
    """
    plot_idxs(idx_df, sesspar, figpar)

    Plots ROI index histograms (or ROI index percentile histograms) for each 
    line/plane and session. Each subplot is drawn by plot_stim_idx_hist, which 
    also receives the binned random ROI indices and the confidence interval 
    values for that line/plane/session.

    Required args:
        - idx_df (pd.DataFrame):
            dataframe with indices for different line/plane combinations, and 
            the following columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
            - CI_perc (list): confidence interval percentile limits
            if plot == "items":
            - roi_idx_binned (list): bin counts for the ROI indices
            if plot == "percs":
            - perc_idx_binned (list): bin counts for the ROI index percentiles
            optionally:
            - n_signif_lo (int): number of significant ROIs (low) 
            - n_signif_hi (int): number of significant ROIs (high)

        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - plot (str): 
            type of data to plot ("items" or "percs")
            default: "items"
        - density (bool): 
            if True, histograms are plotted as densities
            default: True
        - n_bins (int): 
            number of bins to use in histograms
            default: 40
        - title (str): 
            plot title
            default: None
        - size (str): 
            plot size ("reg", "small" or "tall")
            default: "reg"
        
    Returns:
        - ax (2D array): 
            array of subplots (rows: plane x session, columns: line)
    """

    # select the data and confidence interval columns to plot
    if plot == "items":
        data_key = "roi_idx_binned"
        CI_key = "CI_edges"
    elif plot == "percs":
        data_key = "perc_idx_binned"
        CI_key = "CI_perc"
    else:
        gen_util.accepted_values_error("plot", plot, ["items", "percs"])

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_df)

    # one subplot per session, for each line/plane (4 in total)
    n_plots = len(sess_ns) * 4
    figpar["init"]["sharey"] = "row"
    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", 
        n_sub=len(sess_ns), sharex=(plot == "percs"))

    y = 1
    if size == "reg":
        subplot_hei = 3.2
        subplot_wid = 5.5
    elif size == "small":
        y = 1.04
        subplot_hei = 2.40
        subplot_wid = 3.75
        figpar["init"]["gs"] = {"hspace": 0.25, "wspace": 0.30}
    elif size == "tall":
        y = 0.98
        subplot_hei = 5.3
        subplot_wid = 5.55
    else:
        gen_util.accepted_values_error("size", size, ["reg", "small", "tall"])
    
    figpar["init"]["subplot_hei"] = subplot_hei
    figpar["init"]["subplot_wid"] = subplot_wid
    # re-set after fig_init_linpla, presumably in case it overrides sharey
    figpar["init"]["sharey"] = "row"
    
    fig, ax = plot_util.init_fig(n_plots, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=y, weight="bold")

    for (line, plane), lp_df in idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for s, sess_n in enumerate(sess_ns):
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected sess_ns to be unique per line/plane."
                    )
            row = rows.loc[rows.index[0]]

            # sessions are stacked vertically within each plane
            sub_ax = ax[s + pl * len(sess_ns), li]

            # get percentage significant label, if significance counts exist
            perc_label = None
            if "n_signif_lo" in row.keys() and "n_signif_hi" in row.keys():
                n_sig_lo, n_sig_hi = row["n_signif_lo"], row["n_signif_hi"]
                nrois = np.sum(row["nrois"])
                perc_signif = np.sum([n_sig_lo, n_sig_hi]) / nrois * 100
                perc_label = (f"{perc_signif:.2f}% sig\n"
                    f"({n_sig_lo}-/{n_sig_hi}+ of {nrois})")

            plot_stim_idx_hist(
                sub_ax, row[data_key], row[CI_key], n_bins=n_bins, 
                rand_data=row["rand_idx_binned"], 
                orig_edges=row["bin_edges"], 
                plot=plot, col=col, density=density, perc_label=perc_label)
            
            if size == "small":
                sub_ax.legend(fontsize="small")

    # Add plane, line info to plots
    y_lab = "Density" if density else "N ROIs"
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", ylab=y_lab, 
        xticks=None, sess_ns=None, kind="idx", modif_share=False, 
        xlab="Index", single_lab=True)

    # Add indices after setting formatting (plot was validated above)
    if plot == "percs":
        nticks = 5
        xticks = [int(np.around(x, 0)) for x in np.linspace(0, 100, nticks)]
        for sub_ax in ax[-1]:
            sub_ax.set_xticks(xticks)
            sub_ax.set_xticklabels(xticks, weight="bold")
    
    elif plot == "items":
        nticks = 3
        plot_util.set_interm_ticks(
            ax, nticks, axis="x", weight="bold", share=False, skip=False
            )
        
    return ax
示例#10
0
def plot_roi_correlations(corr_df, figpar, title=None, log_scale=True):
    """
    plot_roi_correlations(corr_df, figpar)

    Plots, for each line/plane and session, a density histogram of binned 
    ROI correlation values.

    Required args:
        - corr_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the 
            following columns, in addition to the basic sess_df columns:
            - bin_edges (list): first and last bin edge
            - corrs_binned (list): number of correlation values per bin
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - log_scale (bool):
            if True, the y axis uses a near logarithmic scale (linear close to 
            0, logarithmic beyond, with break marks at the transition), the x 
            axis is shared and fixed to [-1, 1]
            default: True

    Returns:
        - ax (2D array): 
            array of subplots (rows: planes, columns: line x session)
    """

    # every session number between the first and last is given a column
    first_sess, last_sess = corr_df.sess_ns.min(), corr_df.sess_ns.max()
    sess_ns = np.arange(first_sess, last_sess + 1)
    n_sess = len(sess_ns)

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="prog", n_sub=n_sess)
    init_par = figpar["init"]
    init_par["subplot_hei"] = 3.0
    init_par["subplot_wid"] = 2.8
    init_par["sharex"] = log_scale
    if log_scale:
        init_par["sharey"] = True

    fig, ax = plot_util.init_fig(4 * n_sess, **init_par)
    if title is not None:
        fig.suptitle(title, y=1.02, weight="bold")

    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", ylab="Density", xlab="Correlation", 
        sess_ns=sess_ns, kind="prog", single_lab=True
        )

    log_base = 2
    hist_kwargs = dict(alpha=0.6, density=True)
    for (line, plane), lp_df in corr_df.groupby(["lines", "planes"]):
        li, pl, lp_col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for s_idx, sess_n in enumerate(sess_ns):
            matches = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if not len(matches):
                continue
            if len(matches) > 1:
                raise RuntimeError("Expected exactly one row.")
            sess_row = matches.loc[matches.index[0]]

            # each line occupies a contiguous run of n_sess columns
            sub_ax = ax[pl, s_idx + li * n_sess]

            bin_counts = np.asarray(sess_row["corrs_binned"])
            edges = np.linspace(*sess_row["bin_edges"], len(bin_counts) + 1)

            # one point per bin, weighted by its count, reproduces the histogram
            sub_ax.hist(
                edges[:-1], edges, weights=bin_counts, color=lp_col, 
                **hist_kwargs
                )
            sub_ax.axvline(
                0, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5
                )

            sub_ax.spines["bottom"].set_visible(True)
            sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)

            if not log_scale:
                sub_ax.autoscale(axis="x", tight=True)
            else:
                sub_ax.set_yscale("log", base=log_base)
                sub_ax.set_xlim(-1, 1)

            sub_ax.autoscale(axis="y", tight=True)

    if not log_scale:
        # share y ticks within each line's group of session columns
        n_grps = int(ax.shape[1] / n_sess)
        for row_axes in ax:
            for g in range(n_grps):
                plot_util.set_interm_ticks(
                    row_axes[g * n_sess:(g + 1) * n_sess], 4, axis="y", 
                    share=True, update_ticks=True
                    )
    else:
        set_symlog_scale(ax, log_base=log_base, col_per_grp=n_sess, n_ticks=4)

    plot_util.set_interm_ticks(
        ax, 4, axis="x", share=log_scale, update_ticks=True, 
        fontweight="bold"
        )

    return ax
示例#11
0
def plot_snr_sigmeans_nrois(data_df,
                            figpar,
                            datatype="snrs",
                            title="ROI SNRs"):
    """
    plot_snr_sigmeans_nrois(data_df, figpar)

    Plots SNR, signal means or number of ROIs, depending on the case.

    SNRs and signal means are plotted as notched boxplots (5-95 percentile 
    whiskers) per session, with a dashed reference line at 1 (SNRs) or 0 
    (signal means). Numbers of ROIs are plotted with plot_nrois.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with SNR, signal mean or number of ROIs data for each 
            session, in addition to the basic sess_df columns
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - datatype (str):
            type of data to plot ("snrs", "signal_means" or "nrois"), also 
            corresponding to column name
            default: "snrs"
        - title (str):
            plot title
            default: "ROI SNRs"

    Returns:
        - ax (2D array): 
            array of subplots
    """

    # validate up front, so that an invalid datatype is reported before any 
    # figure is created, and even if data_df has no rows
    accepted_datatypes = ["snrs", "signal_means", "nrois"]
    if datatype not in accepted_datatypes:
        gen_util.accepted_values_error("datatype", datatype, accepted_datatypes)

    sess_ns = np.arange(data_df.sess_ns.min(), data_df.sess_ns.max() + 1)

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")
    figpar["init"]["sharey"] = "row"

    figpar["init"]["subplot_hei"] = 4.4
    figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.2}
    if datatype != "nrois":
        figpar["init"]["subplot_wid"] = 3.2
    else:
        figpar["init"]["subplot_wid"] = 2.5

    fig, ax = plot_util.init_fig(4, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)

        sub_ax = ax[pl, li]

        if datatype == "nrois":
            plot_nrois(sub_ax, lp_df, sess_ns=sess_ns, col=col, dash=dash)
            continue

        # reference line: SNR of 1, or signal mean of 0
        ref_y = 1 if datatype == "snrs" else 0
        sub_ax.axhline(y=ref_y,
                       ls=plot_helper_fcts.HDASH,
                       c="k",
                       lw=3.0,
                       alpha=0.5)

        # gather values per session, skipping sessions with no data
        data = []
        use_sess_ns = []
        for sess_n in sess_ns:
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) > 0:
                use_sess_ns.append(sess_n)
                data.append(np.concatenate(rows[datatype].tolist()))

        sub_ax.boxplot(data,
                       positions=use_sess_ns,
                       notch=True,
                       patch_artist=True,
                       whis=[5, 95],
                       widths=0.6,
                       boxprops=dict(facecolor="white",
                                     color=col,
                                     linewidth=3.0),
                       capprops=dict(color=col, linewidth=3.0),
                       whiskerprops=dict(color=col, linewidth=3.0),
                       flierprops=dict(color=col,
                                       markeredgecolor=col,
                                       markersize=8),
                       medianprops=dict(color=col, linewidth=3.0))

    sess_plot_util.format_linpla_subaxes(ax,
                                         datatype="roi",
                                         ylab="",
                                         xticks=sess_ns,
                                         kind="reg",
                                         single_lab=True)

    return ax
示例#12
0
def plot_roi_tracking(roi_mask_df, figpar, title=None):
    """
    plot_roi_tracking(roi_mask_df, figpar)
    
    Plots ROI tracking examples, for different session permutations, and union 
    across permutations.

    One row of subplots is drawn, with the columns: "fewest" tracked ROIs, 
    "most" tracked ROIs, an empty spacer column (holding the 
    "Union - conflicts" arrow title and the session color legend), and the 
    "union" of tracked ROIs. Match/conflict counts are logged.

    NOTE: figpar["init"] is modified in place.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with exactly one row (a single example), and the 
            following columns, in addition to the basic sess_df columns: 
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)
            - "union_n_conflicts" (int): number of conflicts after union
            for "union", "fewest" and "most" tracked ROIs:
            - "{}_registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val),
                ordered by {}_sess_ns if "fewest" or "most"
            - "{}_n_tracked" (int): number of tracked ROIs
            for "fewest", "most" tracked ROIs:
            - "{}_sess_ns" (list): ordered session number 
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - ValueError: if roi_mask_df does not have exactly one row
    """

    if len(roi_mask_df) != 1:
        raise ValueError("Expected only one row in roi_mask_df")
    roi_mask_row = roi_mask_df.loc[roi_mask_df.index[0]]

    # "" is a spacer column (arrow title + legend, no masks)
    columns = ["fewest", "most", "", "union"]

    figpar["init"]["ncols"] = len(columns)
    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 5.05
    figpar["init"]["subplot_wid"] = 5.05
    figpar["init"]["gs"] = {"wspace": 0.06}

    # MUST ADJUST if anything above changes 
    # [left, bottom, width, height], in figure coordinates (fig.add_axes rect)
    new_axis_coords = [0.905, 0.125, 0.06, 0.74] 

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    # separate axis holding only the scale marker (left spine kept visible)
    sub_ax_scale = fig.add_axes(new_axis_coords)
    plot_util.remove_axis_marks(sub_ax_scale)
    sub_ax_scale.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=1.05, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6
    for c, column in enumerate(columns):
        sub_ax = ax[0, c]

        # first column: label the row with the line/plane name, and start the 
        # log message (extended by each following mask column)
        if c == 0:
            lp_col = plot_helper_fcts.get_line_plane_idxs(
                roi_mask_row["lines"], roi_mask_row["planes"]
                )[2]

            lp_name = plot_helper_fcts.get_line_plane_name(
                roi_mask_row["lines"], roi_mask_row["planes"]
                )
            sub_ax.set_ylabel(lp_name, fontweight="bold", color=lp_col)
            log_info = f"Conflicts and matches for a {lp_name} example:"

        if column == "":
            # spacer column: hidden axis, with an arrow-like text pointing to 
            # the union column
            sub_ax.set_axis_off()
            subplot_title = \
                "     Union - conflicts\n...   ====================>"
            sub_ax.set_title(subplot_title, fontweight="bold", y=0.5)
            continue
        else:
            # mask columns: no ticks, but a full frame
            plot_util.remove_axis_marks(sub_ax)
            for spine in ["right", "left", "top", "bottom"]:
                sub_ax.spines[spine].set_visible(True)

        if column in ["fewest", "most"]:
            y = 1.01
            # session order in which this column's registered masks are stored
            ord_sess_ns = roi_mask_row[f"{column}_sess_ns"]
            ord_sess_ns_str = ", ".join([str(n) for n in ord_sess_ns])

            n_matches = int(roi_mask_row[f"{column}_n_tracked"])
            subplot_title = f"{n_matches} matches\n(sess {ord_sess_ns_str})"
            log_info = (f"{log_info}\n{TAB}"
                f"{column.capitalize()} matches (sess {ord_sess_ns_str}): "
                f"{n_matches}")
        
        elif column == "union":
            y = 1.04
            # union masks are stored in the row's own session order
            ord_sess_ns = roi_mask_row["sess_ns"]
            n_union = int(roi_mask_row[f"{column}_n_tracked"])
            n_conflicts = int(roi_mask_row[f"{column}_n_conflicts"])
            n_matches = n_union - n_conflicts

            subplot_title = f"{n_matches} matches"
            log_info = (f"{log_info}\n{TAB}"
                "Union - conflicts: "
                f"{n_union} - {n_conflicts} = {n_matches} matches"
                )

        sub_ax.set_title(subplot_title, fontweight="bold", y=y)

        roi_masks = create_sess_roi_masks(
            roi_mask_row, 
            mask_key=f"{column}_registered_roi_mask_idxs"
            )
        
        # overlay each session's masks in its session color, looking up the 
        # mask by the session's position in ord_sess_ns
        # NOTE(review): .index() assumes ord_sess_ns is a list (not an array)
        for sess_n in roi_mask_row["sess_ns"]:
            col = sess_cols[int(sess_n)]
            s = ord_sess_ns.index(sess_n)
            add_roi_mask(sub_ax, roi_masks[s], col=col, alpha=alpha)

    # add scale marker, sized by the mask image height (shape: sess x hei x wid)
    hei_len = roi_mask_row["roi_mask_shapes"][1]
    add_scale_marker(
        sub_ax_scale, side_len=hei_len, ori="vertical", quadrant=3, fontsize=20
        )

    logger.info(log_info, extra={"spacing": "\n"})

    # add legend, in the spacer column
    add_sess_col_leg(
        ax[0, columns.index("")], 
        sess_cols, 
        bbox_to_anchor=(0.67, 0.3), 
        alpha=alpha
        )

    return ax
示例#13
0
def plot_roi_masks_overlayed(roi_mask_df, figpar, title=None):
    """
    plot_roi_masks_overlayed(roi_mask_df, figpar)

    Plots, for each line/plane, the registered ROI masks of every session 
    overlayed in a single subplot (one color per session), optionally 
    cropped, with a session color legend and a vertical scale marker.

    Required args:
        - roi_mask_df (pd.DataFrame):
            dataframe with exactly one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)
            
            and optionally, if cropping (detected by the presence of 
            "crop_fact"):
            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally [0, 1]

        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane does not have exactly one row
        - NotImplementedError: if mask image heights differ across planes
    """

    crop = "crop_fact" in roi_mask_df.columns

    figpar = sess_plot_util.fig_init_linpla(figpar)
    init_par = figpar["init"]
    init_par["sharex"] = False
    init_par["sharey"] = False
    init_par["subplot_hei"] = 5.2
    init_par["subplot_wid"] = 5.2
    init_par["gs"] = {"wspace": 0.03, "hspace": 0.32}

    # scale marker axis rect: [left, bottom, width, height] in figure 
    # coordinates. MUST ADJUST if the figure parameters above change.
    # When cropping, the scale marker is moved to the left.
    scale_left = 0.04 if crop else 0.885
    scale_rect = [scale_left, 0.11, 0.1, 0.33]

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_par)

    sub_ax_scale = fig.add_axes(scale_rect)
    plot_util.remove_axis_marks(sub_ax_scale)
    sub_ax_scale.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=0.95, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6
    mask_heights = []
    for (line, plane), lp_mask_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        lp_ax = ax[pl, li]

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]
        
        sess_masks = create_sess_roi_masks(lp_row, crop=crop)
        mask_heights.append(sess_masks.shape[1])

        for m_idx, sess_n in enumerate(lp_row["sess_ns"]):
            add_roi_mask(
                lp_ax, sess_masks[m_idx], col=sess_cols[int(sess_n)], 
                alpha=alpha
                )

    add_sess_col_leg(
        ax[0, 1], sess_cols, bbox_to_anchor=(0.7, -0.01), alpha=alpha
        )

    # a single scale marker is only meaningful if all mask heights agree
    unique_heights = np.unique(mask_heights)
    if len(unique_heights) != 1:
        raise NotImplementedError(
            "Adding scale bar not implemented if ROI mask image heights are "
            "different for different planes."
            )
    add_scale_marker(
        sub_ax_scale, side_len=unique_heights[0], ori="vertical", 
        quadrant=(1 if crop else 3), fontsize=22
        )
 
    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, ylab="", kind="map")

    return ax
示例#14
0
def plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar, title=None):
    """
    plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar)

    Plots ROI mask contours overlayed over imaging planes, and ROI masks 
    overlayed over each other across sessions.

    Each line/plane gets its own group of subplots, 2 rows high and n_sess 
    columns wide (n_sess being the number of session colors); groups for 
    different lines are placed side by side, and groups for different planes 
    are stacked vertically.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with exactly one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 

            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_idxs" (list): list of mask indices for each session, 
                and each ROI (sess x (ROI, hei, wid) x val) (not registered)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)

            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically from left edge [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally from left edge [0, 1]

        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if fewer than 2 sessions are present (scale 
            bar and legend placement are not implemented for that case)
        - RuntimeError: if a line/plane does not have exactly one row
    """

    n_lines = len(roi_mask_df["lines"].unique())
    n_planes = len(roi_mask_df["planes"].unique())

    sess_cols = get_sess_cols(roi_mask_df)
    n_sess = len(sess_cols)

    # fail early, before any figure is created
    if n_sess < 2:
        raise NotImplementedError(
            "Scale bar and legend placement not implemented for fewer than 2 "
            "sessions."
            )

    # each line spans n_sess columns, and each plane spans 2 rows
    n_rows_per_plane = 2
    n_cols = n_sess * n_lines

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 2.3
    figpar["init"]["subplot_wid"] = 2.3
    figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.2}
    figpar["init"]["ncols"] = n_cols

    fig, ax = plot_util.init_fig(
        n_cols * n_planes * n_rows_per_plane, **figpar["init"]
        )

    if title is not None:
        fig.suptitle(title, y=0.93, weight="bold")

    crop = "crop_fact" in roi_mask_df.columns

    alpha = 0.6
    raster_zorder = -12

    for (line, plane), lp_mask_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, lp_col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]

        # identify this line/plane's subplot group (the modulo maps the index 
        # to 0 when only a single line or plane is present)
        base_row = (pl % n_planes) * n_rows_per_plane
        base_col = (li % n_lines) * n_sess

        ax_grp = ax[
            base_row : base_row + n_rows_per_plane, 
            base_col : base_col + n_sess
            ]

        # add imaging planes and masks
        imaging_planes = add_proj_and_roi_masks(
            ax_grp, lp_row, sess_cols, crop=crop, alpha=alpha, 
            proj_zorder=raster_zorder - 1
            )

        # add markings, centered on the group's second row
        shared_row = base_row + 1
        shared_col = base_col + (n_sess - 1) // 2
        shared_sub_ax = ax[shared_row, shared_col]

        if shared_col == base_col:
            shared_sub_ax.set_ylabel(lp_name, fontweight="bold", color=lp_col)
        else:
            # write the name as text in the group's first column instead
            lp_sub_ax = ax[shared_row, base_col]
            lp_sub_ax.set_xlim([0, 1])
            lp_sub_ax.set_ylim([0, 1])
            lp_sub_ax.text(
                0.5, 0.5, lp_name, fontweight="bold", color=lp_col, 
                ha="center", va="center", fontsize="x-large"
                )

        # add scale bar, in the group's last column
        scale_ax = ax[shared_row, base_col + n_sess - 1]
        wid_len = imaging_planes[0].shape[-1]
        add_scale_marker(
            scale_ax, side_len=wid_len, ori="horizontal", quadrant=1, 
            fontsize=20
            )

    logger.info("Rasterizing imaging plane images...", extra={"spacing": TAB})
    for i in range(ax.shape[0]):
        for j in range(ax.shape[1]):
            sub_ax = ax[i, j]
            plot_util.remove_axis_marks(sub_ax)
            # rasterize low-zorder artists on the first row of each group
            if not(i % 2):
                sub_ax.set_rasterization_zorder(raster_zorder)

    # add legend
    add_sess_col_leg(
        ax[-1, -1], sess_cols, bbox_to_anchor=(1, 0.6), alpha=alpha, 
        fontsize="small"
        )

    return ax
示例#15
0
def plot_idx_corr_scatterplots(idx_corr_df,
                               permpar,
                               figpar,
                               permute="sess",
                               title=None):
    """
    plot_idx_corr_scatterplots(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlation scatterplots for individual session 
    comparisons.

    For each line/plane, the ROI USIs are scattered over contours of the 
    binned random correlation statistics. In a second pass, markers are added 
    with add_scatterplot_markers, and correlation p-values are logged.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with exactly one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for correlation data for session comparisons (x, y), e.g. 1v2
            - binned_rand_stats (list): number of random correlation values per 
                bin (xs x ys)
            - corr_data_xs (list): USI values for x
            - corr_data_ys (list): USI values for y
            - corrs (float): correlation between session data (x and y)
            - p_vals (float): p-value for correlation, corrected for 
                multiple comparisons and tails
            - rand_corr_meds (float): median of the random correlations
            - raw_p_vals (float): p-value for intersession correlations
            - regr_coefs (float): regression correlation coefficient (slope)
            - regr_intercepts (float): regression correlation intercept
            - x_bin_mids (list): x mid point for each random correlation bin
            - y_bin_mids (list): y mid point for each random correlation bin

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all"). If 
            "sess" or "all", y values are labelled as session USI differences.
            default: "sess"
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane does not have exactly one row, or if 
            sess_ns differ between rows
    """

    # for "sess" and "all" permutations, y values are session differences
    diffs = permute in ["sess", "all"]

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 4
    figpar["init"]["subplot_wid"] = 4
    figpar["init"]["gs"] = {"hspace": 0.4, "wspace": 0.4}

    fig, ax = plot_util.init_fig(4, **figpar["init"])

    if title is not None:
        fig.suptitle(title, fontweight="bold", y=0.97)

    # axis labels are set from the first row's session numbers; default to 
    # empty labels, so that an empty idx_corr_df does not leave them unbound
    sess_ns = None
    xlabel, ylabel = "", ""

    # first pass to plot
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(lp_df) != 1:
            raise RuntimeError("Expected exactly one row.")
        lp_row = lp_df.loc[lp_df.index[0]]

        if sess_ns is None:
            sess_ns = lp_row["sess_ns"]
            xlabel = f"Session {sess_ns[0]} USIs"
            ylabel = f"Session {sess_ns[1]} USIs"
            if diffs:
                ylabel = f"Session {sess_ns[1]} - {sess_ns[0]} USIs"

        elif sess_ns != lp_row["sess_ns"]:
            raise RuntimeError("Expected all sess_ns to match.")

        # contours of the random correlation statistics, drawn behind the data
        density_data = [
            lp_row["x_bin_mids"], lp_row["y_bin_mids"],
            np.asarray(lp_row["binned_rand_stats"]).T
        ]
        sub_ax.contour(*density_data,
                       levels=6,
                       cmap="Greys",
                       zorder=-13,
                       linewidths=4)

        # more points -> more transparent points
        alpha = 0.3**(len(lp_row["corr_data_xs"]) / 300)
        sub_ax.scatter(lp_row["corr_data_xs"],
                       lp_row["corr_data_ys"],
                       color=col,
                       alpha=alpha,
                       lw=2,
                       s=35)

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax,
                                         datatype="roi",
                                         xticks=None,
                                         ylab=ylabel,
                                         xlab=xlabel,
                                         kind="reg")

    # second pass to add plot markings
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    sig_str = ""
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
        sub_ax = ax[pl, li]

        # add markers back in (removed due to kind='reg')
        sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)
        sub_ax.spines["bottom"].set_visible(True)

        lp_row = lp_df.loc[lp_df.index[0]]

        p_val_corr = lp_row["p_vals"]
        lp_sig_str = add_scatterplot_markers(sub_ax,
                                             permpar,
                                             lp_row["corrs"],
                                             lp_row["rand_corr_meds"],
                                             lp_row["regr_coefs"],
                                             lp_row["regr_intercepts"],
                                             p_val_corr,
                                             col=col,
                                             diffs=diffs)

        sig_str = (f"{sig_str}{TAB}{line_plane_name}: "
                   f"{p_val_corr:.5f}{lp_sig_str:3}")

    logger.info(sig_str, extra={"spacing": TAB})

    return ax
示例#16
0
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False, 
                       title=None):
    """
    plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar)

    Plots percentage of significant USIs, with one subplot per tail (low, 
    high) and one x position per line/plane.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns:
            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation 
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig. ROIs 
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for percent 
                sig. ROIs
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for 
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig. 
                ROIs, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - title (str):
            plot title
            default: None
        
    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if perc_sig_df spans more than one session
        - RuntimeError: if, when not plotting by mouse, a line/plane has more 
            than one row
    """  

    perc_sig_df = perc_sig_df.copy(deep=True)

    nanpol = None if analyspar["rem_bad"] else "omit"

    sess_ns = perc_sig_df["sess_ns"].unique()
    if len(sess_ns) != 1:
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."
            )

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", n_sub=1, 
        sharex=True, sharey=True)

    figpar["init"]["sharey"] = True
    figpar["init"]["subplot_wid"] = 3.4
    figpar["init"]["gs"] = {"wspace": 0.18}
    if by_mouse:
        figpar["init"]["subplot_hei"] = 8.4
    else:
        figpar["init"]["subplot_hei"] = 3.5

    fig, ax = plot_util.init_fig(2, **figpar["init"])
    if title is not None:
        y = 0.98 if by_mouse else 1.07
        fig.suptitle(title, y=y, weight="bold")

    tail_order = ["Low tail", "High tail"]
    tail_keys = ["lo", "hi"]
    # chance percentage of ROIs expected in each tail
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    # x tick labels, stored by x position (2 * line index + plane index), so 
    # that each label sits under its own data whatever order groupby yields, 
    # and the number of labels always matches the number of ticks, even if 
    # some line/plane combinations are absent from the data
    linpla_names = [""] * n_linpla

    comp_info = misc_analys.get_comp_info(permpar)
    
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for t, (tail, key) in enumerate(zip(tail_order, tail_keys)):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")
        sub_ax.set_ylim(ylims)

        # replace bottom spine with line at 0
        sub_ax.spines['bottom'].set_visible(False)
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        CIs = np.full((n_linpla, 2), np.nan)
        CI_meds = np.full(n_linpla, np.nan)

        tail_sig_str = f"{tail:9}:"
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            linpla_names[x_index] = linpla_name
            
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns") # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col, 
                    interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                    )
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                CI_dummy = np.repeat(mouse_data_mean, 2)
                plot_util.plot_CI(sub_ax, CI_dummy, med=mouse_data_mean, 
                    x=x_index, width=0.6, med_col=col, med_rat=0.01)
            else:
                # collect confidence interval data (null CIs are stored as 
                # [low, median, high])
                row = lp_df.loc[df_indices[0]]
                mouse_cols = [col]
                CIs[x_index] = np.asarray(row[f"{data_key}_null_CIs"])[
                    np.asarray([0, 2])
                    ]
                CI_meds[x_index] = row[f"{data_key}_null_CIs"][1]

            if by_mouse:
                perc_p_vals = []
                rel_y = 0.05
            else:
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                # plot UFOs
                err = None
                no_line = True
                if not by_mouse:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False
                # indicate bootstrapped error with wider capsize
                plot_util.plot_ufo(
                    sub_ax, x_index, perc_sig_df.loc[df_i, data_key], err,
                    color=mouse_col, capsize=8, no_line=no_line
                    )

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                perc = perc_sig_df.loc[df_i, data_key]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                side = np.sign(perc - chance)
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                    )                
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side, 
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if len(sig_str):
                    perc_high = perc + err if err is not None else perc
                    plot_util.add_signif_mark(sub_ax, x_index, perc_high, 
                        rel_y=rel_y, color=mouse_col, fontsize=24, 
                        mark=sig_str) 

                if by_mouse:
                    perc_p_vals.append(
                        (int(np.around(perc)), p_val, sig_str)
                    )
                else:
                    tail_sig_str = (
                        f"{tail_sig_str}{p_val:.5f}{sig_str:3}"
                        )

            if by_mouse: # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
                        f"{p_val:.5f}{sig_str:3}"
                        )
                
        # add chance information
        if by_mouse:
            sub_ax.axhline(
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5, 
                zorder=-12
                )
        else:
            plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, 
                x=np.arange(n_linpla), width=0.45, med_rat=0.025, zorder=-12)

        logger.info(tail_sig_str, extra={"spacing": TAB})
    
    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False) 
        plot_util.set_ticks(
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2)
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    ax[0, 0].set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if ax[0, 0].get_ylim()[0] < 0:
        ax[0, 0].set_ylim([ylims[0], ax[0, 0].get_ylim()[1]])

    return ax
示例#17
0
def plot_sess_traces(data_df, analyspar, sesspar, figpar, 
                     trace_col="trace_stats", row_col="sess_ns", 
                     row_order=None, split="by_exp", title=None, size="reg"):
    """
    plot_sess_traces(data_df, analyspar, sesspar, figpar) 
    
    Plots traces from dataframe, with one subplot row per row_col value within 
    each line/plane.

    Required args:
        - data_df (pd.DataFrame):
            traces data frame with, in addition to the basic sess_df columns, 
            columns specified by trace_col, row_col, and a "time_values" column
        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - trace_col (str):
             dataframe column containing trace statistics, as 
             split x ROIs x frames x stats 
             default: "trace_stats"
        - row_col (str):
            dataframe column specifying the variable that defines rows 
            within each line/plane
            default: "sess_ns"
        - row_order (list):
            ordered list specifying the order in which to plot from row_col.
            If None, automatic sorting order is used.
            default: None 
        - split (str):
            data split, e.g. "by_exp", "exp_lock", "unexp_lock", "stim_onset" 
            or "stim_offset". If not "by_exp", x limits are made symmetric 
            around 0.
            default: "by_exp"
        - title (str):
            plot title
            default: None
        - size (str):
            subplot sizes ("small", "wide" or "reg")
            default: "reg"

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a row_order value matches more than one row for a 
            line/plane
    """
    
    # infer row_order, if necessary
    if row_order is None:
        if row_col == "sess_ns":
            row_order = misc_analys.get_sess_ns(sesspar, data_df)
        else:
            row_order = data_df[row_col].unique()

    # rows are labelled with session numbers only if they represent sessions 
    # (retrieved after row_order inference, so that inferred sessions are 
    # passed on too)
    sess_ns = row_order if row_col == "sess_ns" else None

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=len(row_order), sharey=False
        )

    if size == "small":
        figpar["init"]["subplot_hei"] = 1.51
        figpar["init"]["subplot_wid"] = 3.7
    elif size == "wide":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 4.8
        figpar["init"]["gs"] = {"wspace": 0.3, "hspace": 0.5}
    elif size == "reg":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 3.4
    else:
        gen_util.accepted_values_error("size", size, ["small", "wide", "reg"])

    fig, ax = plot_util.init_fig(len(row_order) * 4, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=1.0, weight="bold")

    # time values of the last plotted row, used to set the x limits
    time_values = None
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for r, row_val in enumerate(row_order):
            rows = lp_df.loc[lp_df[row_col] == row_val]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected row_order instances to be unique per line/plane."
                    )
            row = rows.loc[rows.index[0]]

            sub_ax = ax[r + pl * len(row_order), li]

            if line == "L2/3-Cux2":
                exp_col = "darkgray" # oddly, lighter than gray
            else:
                exp_col = "gray"

            plot_traces(
                sub_ax, row["time_values"], row[trace_col], split=split, 
                col=col, ls=dash, exp_col=exp_col, lab=False
                )
            time_values = row["time_values"]
            
    for sub_ax in ax.reshape(-1):
        plot_util.set_minimal_ticks(sub_ax, axis="y")

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=False, datatype="roi", sess_ns=sess_ns, xticks=None, 
        kind="traces", modif_share=False)

    # fix x ticks and lims
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    if time_values is not None: # skip if nothing was plotted
        xlims = [np.min(time_values), np.max(time_values)]
        if split != "by_exp":
            # make x limits symmetric around 0
            xlims = [-xlims[1], xlims[1]]
        # apply to every subplot, not just one, in case x is not shared
        for sub_ax in ax.reshape(-1):
            sub_ax.set_xlim(xlims)

    return ax
示例#18
0
def plot_ex_roi_hists(ex_idx_df, sesspar, permpar, figpar, title=None):
    """
    plot_ex_roi_hists(ex_idx_df, sesspar, permpar, figpar)

    Plot example ROI histograms: the random index distribution, with its 
    confidence interval edges and the example ROI's index marked.

    Required args:
        - ex_idx_df (pd.DataFrame):
            dataframe with a row for the example ROI, and the following 
            columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
            - CI_perc (list): confidence interval percentile limits
            - roi_idxs (num): index value of the example ROI
            - roi_idx_percs (num): percentile of the example ROI index 
                within the random distribution
        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane has more than one row
    """
    
    # add dummy (all-zero) "roi_idx_binned" counts, presumably so that 
    # plot_idxs only shows the random index distribution
    ex_idx_df = copy.deepcopy(ex_idx_df)
    ex_idx_df["roi_idx_binned"] = [
        np.zeros_like(rand_idx_binned) 
        for rand_idx_binned in ex_idx_df["rand_idx_binned"].tolist()
    ]
    
    with gen_util.TempWarningFilter("invalid value", RuntimeWarning):
        ax = plot_idxs(
            ex_idx_df, sesspar, figpar, plot="items", title=title, size="tall", 
            density=True, n_bins=40)

    # adjust x axes
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xticks([-0.5, 0, 0.5])
        sub_ax.set_xticklabels(["-0.5", "0", "0.5"])

    # same for every line/plane, so retrieved only once
    sensitivity = misc_analys.get_sensitivity(permpar)

    # add lines and labels
    for (line, plane), lp_df in ex_idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)

        if len(lp_df) == 0:
            continue
        elif len(lp_df) > 1:
            raise RuntimeError("Expected at most one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        sub_ax = ax[pl, li]
        xlims = sub_ax.get_xlim()

        # add CI markers, shading from each CI edge out to the matching 
        # x limit (first edge -> left limit, second edge -> right limit)
        for c, (CI_val, CI_perc) in enumerate(
            zip(row["CI_edges"], row["CI_perc"])
            ):

            sub_ax.axvline(
                CI_val, ls=plot_helper_fcts.VDASH, c="red", lw=3.0, alpha=1.0, 
                label=f"p{CI_perc:0.2f}")
            sub_ax.axvspan(
                CI_val, xlims[c], color=plot_helper_fcts.DARKRED, alpha=0.1, 
                lw=0, zorder=-13
                )
        
        ex_perc = row["roi_idx_percs"]

        sig_str = misc_analys.get_sig_symbol(
            ex_perc, percentile=True, sensitivity=sensitivity, 
            p_thresh=permpar["p_val"]
            )

        # mark the example ROI index
        sub_ax.axvline(
            x=row["roi_idxs"], ls=plot_helper_fcts.VDASH, c=col, lw=3.0, 
            alpha=0.8, label=f"p{ex_perc:0.2f}{sig_str}"
            )

        sub_ax.axvline(
            x=0, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5
            )

        # reset the x limits (axvspan may have extended them)
        sub_ax.set_xlim(xlims)

        sub_ax.legend()
    
    return ax
示例#19
0
def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar, 
                   between_sess_sig=True, data_col="diff_stats", 
                   decoder_data=False, title=None, wide=False):
    """
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots errorbar data across sessions, in one subplot per line/plane, with 
    null confidence intervals and within-session significance markers.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_key} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data
            - raw_p_vals (float): uncorrected p-value for data within 
                sessions
            - p_vals (float): p-value for data within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data 
                differences between sessions 
            - p_vals_{}v{} (float): p-value for data between sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True         
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False
        
    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane has more than one row for a session
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)
    init_par = figpar["init"]
    init_par["sharey"] = True if decoder_data else "row"
    init_par["subplot_hei"] = 4.4
    init_par["subplot_wid"] = 3.0 if wide else 2.6
    init_par["gs"] = {"hspace": 0.2, "wspace": 0.3}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_par)

    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    def _sess_rows(lp_df):
        """Returns df indices and session numbers found for a line/plane."""
        found_idxs, found_sess_ns = [], []
        for sess_n in sess_ns:
            sess_rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(sess_rows) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            if len(sess_rows) == 1:
                found_idxs.append(sess_rows.index[0])
                found_sess_ns.append(sess_n)
        return found_idxs, found_sess_ns

    grouped = data_df.groupby(["lines", "planes"])

    # first pass: plot the data errorbars
    for (line, plane), lp_df in grouped:
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        row_idxs, lp_sess_ns = _sess_rows(lp_df)
        stats = np.asarray([lp_df.loc[idx, data_col] for idx in row_idxs])
        plot_util.plot_errorbars(
            ax[pl, li], stats[:, 0], stats[:, 1:].T, lp_sess_ns, color=col, 
            alpha=0.8, xticks="auto", line_dash=dash
            )

    # second pass: add null CIs and within-session significance markers
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for (line, plane), lp_df in grouped:
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
        sub_ax = ax[pl, li]

        row_idxs, lp_sess_ns = _sess_rows(lp_df)
        stats = np.asarray([lp_df.loc[idx, data_col] for idx in row_idxs])

        # null CIs are stored as [low, median, high]
        null_CIs = np.asarray([lp_df.loc[idx, "null_CIs"] for idx in row_idxs])
        CI_meds = null_CIs[:, 1]
        plot_util.plot_CI(sub_ax, null_CIs[:, np.asarray([0, 2])].T, 
            med=CI_meds, x=lp_sess_ns, width=0.45, color="lightgrey", 
            med_col="gray", med_rat=0.03, zorder=-12)

        p_vals_corr = [lp_df.loc[idx, "p_vals"] for idx in row_idxs]
        lp_sig_str = f"{line_plane_name:6} (within session):"
        for sess_n, sess_stats, CI_med, p_val in zip(
            lp_sess_ns, stats, CI_meds, p_vals_corr
            ):
            # side of the null median on which the data lies
            side = np.sign(sess_stats[0] - CI_med)
            sig_str = misc_analys.get_sig_symbol(
                p_val, sensitivity=sensitivity, side=side, 
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            if len(sig_str):
                # place marker above the top of the errorbar
                y_max = sess_stats[0] + sess_stats[-1]
                plot_util.add_signif_mark(sub_ax, sess_n, y_max, 
                    rel_y=0.15, color=col, mark=sig_str)  

            lp_sig_str = f"{lp_sig_str}{TAB} S{sess_n}: {p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})
        
    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    if decoder_data:
        area = False
        ylab = "Balanced accuracy (%)" if "balanced" in data_col else "Accuracy %"
    else:
        area, ylab = True, None

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=area, ylab=ylab, datatype="roi", sess_ns=sess_ns, kind="reg", 
        xticks=sess_ns, modif_share=False)

    return ax
示例#20
0
def plot_tracked_idxs(idx_only_df, sesspar, figpar, title=None, wide=False):
    """
    plot_tracked_idxs(idx_only_df, sesspar, figpar)

    Plots tracked ROI USIs as individual lines across sessions, in one subplot 
    per line/plane.

    Required args:
        - idx_only_df (pd.DataFrame):
            dataframe with one row per (mouse/)session/line/plane, and the 
            following columns, in addition to the basic sess_df columns 
            (including "mouse_ns", "sess_ns" and "nrois"):
            - roi_idxs (list): index for each ROI

        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

    Returns:
        - ax (2D array): 
            array of subplots    

    Raises:
        - RuntimeError: if a mouse's number of ROIs varies across sessions, or 
            if a mouse has more than one row for a line/plane/session
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_only_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    figpar["init"]["subplot_hei"] = 4.1
    figpar["init"]["subplot_wid"] = 3.3 if wide else 2.5
    figpar["init"]["gs"] = {"wspace": 0.25, "hspace": 0.2}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    for (line, plane), lp_df in idx_only_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        lp_mouse_ns = sorted(lp_df["mouse_ns"].unique())

        # collect session x ROI arrays for each mouse (NaN for missing 
        # sessions)
        lp_data = []
        for mouse_n in lp_mouse_ns:
            mouse_df = lp_df.loc[lp_df["mouse_ns"] == mouse_n]
            nrois = mouse_df["nrois"].unique()
            if len(nrois) != 1:
                raise RuntimeError(
                    "Each mouse in idx_only_df should retain the same number "
                    "of ROIs across sessions.")
            
            mouse_data = np.full((len(sess_ns), nrois[0]), np.nan)
            for s, sess_n in enumerate(sess_ns):
                rows = mouse_df.loc[mouse_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                    mouse_data[s] = rows.loc[rows.index[0], "roi_idxs"]
                elif len(rows) > 1:
                    raise RuntimeError(
                        "Expected 1 row per line/plane/session/mouse."
                        )
            lp_data.append(mouse_data)

        # session x ROI (all mice)
        lp_data = np.concatenate(lp_data, axis=1)

        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, 
            zorder=-13
            )
        # one line per ROI
        sub_ax.plot(sess_ns, lp_data, color=col, lw=2, alpha=0.3)
    
    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", xticks=sess_ns, ylab="", kind="reg"
        )

    for sub_ax in ax.reshape(-1):
        xticks = sub_ax.get_xticks()
        plot_util.set_ticks(
            sub_ax, "x", np.min(xticks), np.max(xticks), n=len(xticks), 
            pad_p=0.2
            )
    
    return ax
示例#21
0
def plot_rel_resp_data(rel_resp_df, analyspar, sesspar, stimpar, permpar, 
                       figpar, title=None, small=True):
    """
    plot_rel_resp_data(rel_resp_df, analyspar, sesspar, stimpar, permpar, 
                       figpar)

    Plots relative response errorbar data across sessions, in one subplot per 
    line/plane, with between-session significance markers.

    Required args:
        - rel_resp_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - rel_reg or rel_exp (list): 
                data stats for regular or expected data (me, err)
            - rel_unexp (list): data stats for unexpected data (me, err)
            for reg/exp/unexp data types, session comparisons, e.g. 1v2:
            - {data_type}_raw_p_vals_{}v{} (float): uncorrected p-value for 
                data differences between sessions 
            - {data_type}_p_vals_{}v{} (float): p-value for data between 
                sessions, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - stimpar (dict):
            dictionary with keys of StimPar namedtuple. 
            stimpar["stimtype"] must be "gabors" (uses "rel_reg") or 
            "visflow" (uses "rel_exp").
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - small (bool):
            if True, subplots are smaller
            default: True

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane has more than one row for a session
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, rel_resp_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    if small:
        figpar["init"]["subplot_hei"] = 4.1
        figpar["init"]["subplot_wid"] = 2.6
        figpar["init"]["gs"] = {"hspace": 0.2, "wspace": 0.25}
    else:
        figpar["init"]["subplot_hei"] = 4.4
        figpar["init"]["subplot_wid"] = 3.0
        figpar["init"]["gs"] = {"wspace": 0.3}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    if stimpar["stimtype"] == "gabors":
        data_types = ["rel_reg", "rel_unexp"]
    elif stimpar["stimtype"] == "visflow":
        data_types = ["rel_exp", "rel_unexp"]
    else:
        gen_util.accepted_values_error(
            "stimpar['stimtype']", stimpar["stimtype"], ["gabors", "visflow"]
            )

    for (line, plane), lp_df in rel_resp_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        sess_indices = []
        lp_sess_ns = []
        for sess_n in sess_ns:
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 1:
                sess_indices.append(rows.index[0])
                lp_sess_ns.append(sess_n)
            elif len(rows) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")

        # dashed reference line at a relative response of 1
        sub_ax.axhline(
            y=1, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, zorder=-13
            )

        # expected/regular data in gray diamonds, unexpected in line/plane color
        colors = ["gray", col]
        fmts = ["-d", "-o"]
        alphas = [0.6, 0.8]
        ms = [12, None]
        for d, data_type in enumerate(data_types):
            data = np.asarray([lp_df.loc[i, data_type] for i in sess_indices])
            plot_util.plot_errorbars(
                sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=colors[d], 
                alpha=alphas[d], ms=ms[d], fmt=fmts[d], line_dash=dash)
            
    # the dry run collects data heights across data types, then the real run 
    # adds the markers, raising each data type's markers 5% above the last
    highest = None
    for dry_run in [True, False]:
        for data_type in data_types:
            ctrl = ("unexp" not in data_type)
            highest = add_between_sess_sig(
                ax, rel_resp_df, permpar, data_col=data_type, highest=highest, 
                ctrl=ctrl, p_val_prefix=True, dry_run=dry_run)
            if not dry_run:
                highest = [val * 1.05 for val in highest] # increment a bit

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        datatype="roi", sess_ns=sess_ns, kind="reg", xticks=sess_ns, 
        modif_share=False)

    return ax

    
示例#22
0
def plot_tracked_idx_stats(idx_stats_df, sesspar, figpar, permpar=None,
                           absolute=True, between_sess_sig=True, 
                           by_mouse=False, bootstr_err=None, title=None, 
                           wide=False):
    """
    plot_tracked_idx_stats(idx_stats_df, sesspar, figpar)

    Plots tracked ROI USI statistics.

    Required args:
        - idx_stats_df (pd.DataFrame):
            dataframe with one row per session (per line/plane, and per mouse 
            if by_mouse is True), and the following columns, in addition to 
            the basic sess_df columns:
            - roi_idx_stats (list): index statistics [statistic, error(s)]
            or
            - abs_roi_idx_stats (list): absolute index statistics 
                [statistic, error(s)]
        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple. Required if 
            between_sess_sig is True.
            default: None
        - absolute (bool):
            if True, data statistics are on absolute ROI indices 
            ("abs_roi_idx_stats" column is used)
            default: True
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted 
            (not implemented if by_mouse is True)
            default: True
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - bootstr_err (bool):
            if True, error is bootstrapped standard deviation, and errorbars 
            are drawn with caps
            default: None
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

    Returns:
        - ax (2D array): 
            array of subplots    

    Raises:
        - KeyError: if the required data column is missing from idx_stats_df
        - ValueError: if between_sess_sig is True, but permpar is None
        - NotImplementedError: if between_sess_sig and by_mouse are both True
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_stats_df)

    data_col = "abs_roi_idx_stats" if absolute else "roi_idx_stats"
    if data_col not in idx_stats_df.columns:
        raise KeyError(f"Expected to find {data_col} in idx_stats_df columns.")

    # Validate argument combinations before any figure is created, so that 
    # invalid calls fail without leaving a half-drawn figure behind.
    if between_sess_sig:
        if permpar is None:
            raise ValueError(
                "If 'between_sess_sig' is True, must provide permpar."
                )
        if by_mouse:
            raise NotImplementedError(
                "Plotting between session statistical significance is not "
                "implemented if 'by_mouse' is True."
                )

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharey"] = "row"
    figpar["init"]["subplot_hei"] = 4.1
    figpar["init"]["subplot_wid"] = 3.3 if wide else 2.6
    figpar["init"]["gs"] = {"wspace": 0.25, "hspace": 0.2}

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # errorbar styling depends only on the arguments, not on the data
    alpha = 0.6 if by_mouse else 0.8
    capsize = 8 if bootstr_err else None

    for (line, plane), lp_df in idx_stats_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5, 
            zorder=-13
            )

        if by_mouse:
            mouse_ns = sorted(lp_df["mouse_ns"].unique())
            mouse_cols = plot_util.get_hex_color_range(
                len(mouse_ns), col=col, 
                interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                )
        else:
            mouse_ns = ["any"]
            mouse_cols = [col]

        for mouse_n, mouse_col in zip(mouse_ns, mouse_cols):
            sub_df = lp_df
            if by_mouse:
                sub_df = lp_df.loc[lp_df["mouse_ns"] == mouse_n]

            # only sessions with exactly one matching row are plotted
            sess_indices = []
            sub_sess_ns = []
            for sess_n in sess_ns:
                rows = sub_df.loc[sub_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                    sess_indices.append(rows.index[0])
                    sub_sess_ns.append(sess_n)

            if not sess_indices:
                # nothing to plot: indexing an empty array below would fail
                continue

            # each entry holds [statistic, error value(s)]
            data = np.asarray([sub_df.loc[i, data_col] for i in sess_indices])

            # plot errorbars
            plot_util.plot_errorbars(
                sub_ax, data[:, 0], data[:, 1:].T, sub_sess_ns, color=mouse_col, 
                alpha=alpha, xticks="auto", line_dash=dash, capsize=capsize,
                )

    if between_sess_sig:
        seq_plots.add_between_sess_sig(
            ax, idx_stats_df, permpar, data_col=data_col
            )

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", xticks=sess_ns, ylab="", kind="reg"
        )

    return ax


def add_between_stim_sig(ax, sub_ax_all, data_df, permpar):
    """
    add_between_stim_sig(ax, sub_ax_all, data_df, permpar)

    Plot significance markers for significant comparisons between stimulus 
    types (Gabors vs visual flow), and log the p-values.

    Required args:
        - ax (2D array): 
            array of line/plane subplots, indexed as ax[plane, line]
        - sub_ax_all (plt subplot): 
            subplot used for rows where line or plane is "all"
        - data_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - gabors (list): Gabor data stats [mean/median, error(s)]
            - visflow (list): visual flow data stats [mean/median, error(s)]
            - p_vals (float): p-value for the difference between stimulus 
                types

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple

    Raises:
        - RuntimeError: if data_df does not have exactly one row per 
          line/plane
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    stimtypes = ["gabors", "visflow"]
    stim_sig_str = "Gabors vs visual flow: "
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):

        if len(lp_df) != 1:
            raise RuntimeError("Expected 1 row per line/plane/session.")
        row_idx = lp_df.index[0]

        x = [0, 1]  # x positions of the two stimulus types
        # transpose so that each column is a stimulus type: row 0 holds the 
        # central values, the remaining rows the error values
        data = np.vstack(
            [lp_df[stimtypes[0]].tolist(), lp_df[stimtypes[1]].tolist()]
            ).T
        y = data[0]
        err = data[1:]
        highest = np.max(y + err[-1])

        if line != "all" and plane != "all":
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]
            mark_rel_y = 0.18
        else:
            col = plot_helper_fcts.NEARBLACK
            linpla_name = "All"
            sub_ax = sub_ax_all
            # on the "all" subplot, anchor the marker to the highest central 
            # value across every row of data_df
            all_data_max = np.concatenate(
                [data_df[stimtypes[0]].tolist(), 
                data_df[stimtypes[1]].tolist()], 
                axis=0
                )[:, 0].max()
            highest = np.max([data[0].max(), all_data_max])
            mark_rel_y = 0.15

        p_val = lp_df.loc[row_idx, "p_vals"]
        # sign of the (visflow - gabors) difference, passed to get_sig_symbol
        side = np.sign(y[1] - y[0])

        sig_str = misc_analys.get_sig_symbol(
            p_val, sensitivity=sensitivity, side=side, tails=permpar["tails"], 
            p_thresh=permpar["p_val"]
            )
        stim_sig_str = \
            f"{stim_sig_str}{TAB}{linpla_name}: {p_val:.5f}{sig_str:3}"

        if len(sig_str):
            plot_util.plot_barplot_signif(
                sub_ax, x, highest, rel_y=0.11, color=col, lw=3, 
                mark_rel_y=mark_rel_y, mark=sig_str, 
            )

    logger.info(stim_sig_str, extra={"spacing": TAB})
示例#24
0
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar, 
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences, one subplot per data type, 
    with per line/plane violins and p-value annotations.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per line/plane for a single session, and 
            the following columns, in addition to the basic sess_df columns: 
            - run_block_diffs (list): 
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list): 
                pupil diameter differences per block
            - pupil_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int): 
            seed value to use (passed to plot_violin_data). 
            (-1 treated as None)
            default: None

    Returns:
        - ax (2D array): 
            array of subplots (1 row: running, then pupil)

    Raises:
        - NotImplementedError: if analyspar["scale"] is set, or if block_df 
          contains more than one session number
        - RuntimeError: if block_df has more than one row per line/plane
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
            )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    # NaN policy for the mean/median used to determine the effect direction
    nanpol = None if analyspar["rem_bad"] else "omit"

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    datatype_titles = {
        "run": "Running velocity (cm/s)",
        "pupil": "Pupil diameter (mm)",
        }
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False, 
        gridspec_kw={"wspace": 0.22}
        )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, (datatype, datatype_str) in enumerate(zip(datatypes, datatype_strs)):
        datatype_sig_str = f"{datatype_str:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
                )
            lp_names[int(x)] = line_plane_name

            # groupby never yields empty groups, so only duplicates need to 
            # be caught
            if len(lp_df) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]

            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]

            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side, 
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
                )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
            )

        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False) 

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False, 
            fontweight="bold", update_ticks=True
            )

        # place p-value texts just above the top of the subplot; computed 
        # as plain floats (np.diff would yield a 1-element array, which 
        # matplotlib later converts with float(), a deprecated conversion)
        y_min, y_max = sub_ax.get_ylim()
        text_y = y_max + (y_max - y_min) * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, text_y, p_val_text, fontsize=20, weight="bold", ha=ha
                )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles (done after all data is plotted, as before)
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
            )

        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")

        # use a dedicated lookup instead of rebinding the `title` argument
        sub_ax.set_title(datatype_titles[datatype], weight="bold", y=1.2)

    return ax