def get_sess_grped_diffs_df(sessions,
                            analyspar,
                            stimpar,
                            basepar,
                            permpar,
                            split="by_exp",
                            randst=None,
                            parallel=False):
    """
    get_sess_grped_diffs_df(sessions, analyspar, stimpar, basepar, permpar)

    Returns split difference statistics for specific sessions, grouped across 
    mice.

    Required args:
        - sessions (list): 
            session objects
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - basepar (BasePar): 
            named tuple containing baseline parameters
        - permpar (PermPar): 
            named tuple containing permutation parameters

    Optional args:
        - split (str): 
            how to split data:
            "by_exp" (all exp, all unexp), 
            "unexp_lock" (unexp, preceding exp), 
            "exp_lock" (exp, preceding unexp),
            "stim_onset" (grayscr, stim on), 
            "stim_offset" (stim off, grayscr)
            default: "by_exp"
        - randst (int or np.random.RandomState): 
            random state or seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - diffs_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - diff_stats (list): split difference stats (me, err)
            - null_CIs (list): adjusted null CI for split differences 
            - raw_p_vals (float): uncorrected p-value for differences within 
                sessions
            - p_vals (float): p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for differences
                between sessions 
            - p_vals_{}v{} (float): p-value for differences between sessions, 
                corrected for multiple comparisons and tails
    """

    # NaNs are omitted from statistics only if bad values were not removed
    nanpol = None if analyspar.rem_bad else "omit"

    if analyspar.tracked:
        misc_analys.check_sessions_complete(sessions, raise_err=True)

    sess_diffs_df = misc_analys.get_check_sess_df(sessions, None, analyspar)
    initial_columns = sess_diffs_df.columns.tolist()

    # retrieve ROI split statistics (and data) for each session
    args_dict = {
        "analyspar": analyspar,
        "stimpar": stimpar,
        "basepar": basepar,
        "split": split,
        "return_data": True,
    }

    # sess x split x ROI
    split_stats, split_data = gen_util.parallel_wrap(get_sess_roi_split_stats,
                                                     sessions,
                                                     args_dict=args_dict,
                                                     parallel=parallel,
                                                     zip_output=True)

    misc_analys.get_check_sess_df(sessions, sess_diffs_df)
    sess_diffs_df["roi_split_stats"] = list(split_stats)
    sess_diffs_df["roi_split_data"] = list(split_data)

    columns = initial_columns + ["diff_stats", "null_CIs"]
    diffs_df = pd.DataFrame(columns=columns)

    group_columns = ["lines", "planes", "sess_ns"]
    aggreg_cols = [col for col in initial_columns if col not in group_columns]
    for lp_grp_vals, lp_grp_df in sess_diffs_df.groupby(["lines", "planes"]):
        lp_grp_df = lp_grp_df.sort_values(["sess_ns", "mouse_ns"])
        line, plane = lp_grp_vals
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)
        logger.info(f"Running permutation tests for {lp_name} sessions...",
                    extra={"spacing": TAB})

        # obtain ROI random split differences per session
        # done here to avoid OOM errors
        lp_rand_diffs = gen_util.parallel_wrap(
            get_rand_split_data,
            lp_grp_df["roi_split_data"].tolist(),
            args_list=[analyspar, permpar, randst],
            parallel=parallel,
            zip_output=False)

        # positional lookup list for lp_rand_diffs (same order as lp_grp_df)
        lp_idx_list = lp_grp_df.index.tolist()

        sess_diffs = []
        row_indices = []
        sess_ns = sorted(lp_grp_df["sess_ns"].unique())
        for sess_n in sess_ns:
            row_idx = len(diffs_df)
            row_indices.append(row_idx)
            sess_grp_df = lp_grp_df.loc[lp_grp_df["sess_ns"] == sess_n]

            grp_vals = list(lp_grp_vals) + [sess_n]
            for g, group_column in enumerate(group_columns):
                diffs_df.loc[row_idx, group_column] = grp_vals[g]

            # add aggregated values for initial columns
            diffs_df = misc_analys.aggreg_columns(sess_grp_df,
                                                  diffs_df,
                                                  aggreg_cols,
                                                  row_idx=row_idx,
                                                  in_place=True)

            # group ROI split stats across mice: split x ROIs
            split_stats = np.concatenate(
                sess_grp_df["roi_split_stats"].to_numpy(), axis=-1)

            # take diff and stats across ROIs
            diffs = split_stats[1] - split_stats[0]
            diff_stats = math_util.get_stats(diffs,
                                             stats=analyspar.stats,
                                             error=analyspar.error,
                                             nanpol=nanpol)
            diffs_df.at[row_idx, "diff_stats"] = diff_stats.tolist()
            sess_diffs.append(diffs)

            # group random ROI split diffs across mice, and take stat
            rand_idxs = [lp_idx_list.index(idx) for idx in sess_grp_df.index]
            rand_diffs = math_util.mean_med(np.concatenate(
                [lp_rand_diffs[r] for r in rand_idxs], axis=0),
                                            axis=0,
                                            stats=analyspar.stats,
                                            nanpol=nanpol)

            # get CIs and p-values
            p_val, null_CI = rand_util.get_p_val_from_rand(
                diff_stats[0],
                rand_diffs,
                return_CIs=True,
                p_thresh=permpar.p_val,
                tails=permpar.tails,
                multcomp=permpar.multcomp,
                nanpol=nanpol)
            diffs_df.loc[row_idx, "p_vals"] = p_val
            diffs_df.at[row_idx, "null_CIs"] = null_CI

        del lp_rand_diffs  # free up memory

        # calculate p-values between sessions (0-1, 0-2, 1-2...)
        p_vals = rand_util.comp_vals_acr_groups(sess_diffs,
                                                n_perms=permpar.n_perms,
                                                stats=analyspar.stats,
                                                paired=analyspar.tracked,
                                                nanpol=nanpol,
                                                randst=randst)

        # store each between-session p-value in the rows of BOTH sessions
        # compared; j is the absolute index of sess_n2 in sess_ns (and thus
        # in row_indices), hence start=i + 1
        p = 0
        for i, sess_n in enumerate(sess_ns):
            for j, sess_n2 in enumerate(sess_ns[i + 1:], start=i + 1):
                key = f"p_vals_{int(sess_n)}v{int(sess_n2)}"
                diffs_df.loc[row_indices[i], key] = p_vals[p]
                diffs_df.loc[row_indices[j], key] = p_vals[p]
                p += 1

    # add corrected p-values
    diffs_df = misc_analys.add_corr_p_vals(diffs_df, permpar)

    diffs_df["sess_ns"] = diffs_df["sess_ns"].astype(int)

    return diffs_df
# Example #2
def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar, 
                   between_sess_sig=True, data_col="diff_stats", 
                   decoder_data=False, title=None, wide=False):
    """
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots errorbar data across sessions, one subplot per line/plane, with 
    null CIs and within-session significance markers (and, optionally, 
    between-session significance markers). Within-session p-values are 
    logged.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_col} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data (low, median, high)
            - raw_p_vals (float): uncorrected p-value for data within 
                sessions
            - p_vals (float): p-value for data within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data 
                differences between sessions 
            - p_vals_{}v{} (float): p-value for data between sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True         
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data (y axes shared across all 
            subplots, accuracy y label)
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False
        
    Returns:
        - ax (2D array): 
            array of subplots
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    # figure layout: decoder data shares y across all subplots, other data 
    # only across rows
    init_pars = figpar["init"]
    init_pars["sharey"] = True if decoder_data else "row"
    init_pars["subplot_hei"] = 4.4
    init_pars["gs"] = {"hspace": 0.2}
    init_pars["subplot_wid"] = 3.0 if wide else 2.6
    init_pars["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_pars)

    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    # first pass: errorbars; second pass: null CIs and significance markers
    for pass_n in (0, 1):
        marking_pass = pass_n == 1
        if marking_pass:
            logger.info(f"{comp_info}:", extra={"spacing": "\n"})

        for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            # collect the single row for each session present for this 
            # line/plane (sessions without a row are skipped)
            row_idxs, lp_sess_ns = [], []
            for sess_n in sess_ns:
                sess_rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                n_rows = len(sess_rows)
                if n_rows > 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")
                if n_rows == 1:
                    row_idxs.append(sess_rows.index[0])
                    lp_sess_ns.append(sess_n)

            # sessions x stats (me, err...)
            data = np.asarray([lp_df.loc[idx, data_col] for idx in row_idxs])

            if not marking_pass:
                plot_util.plot_errorbars(
                    sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=col, 
                    alpha=0.8, xticks="auto", line_dash=dash
                    )
                continue

            # null CIs: column 1 is the median, columns 0 and 2 the bounds
            null_CIs = np.asarray(
                [lp_df.loc[idx, "null_CIs"] for idx in row_idxs]
                )
            CI_meds = null_CIs[:, 1]
            CI_lims = null_CIs[:, [0, 2]]

            plot_util.plot_CI(sub_ax, CI_lims.T, med=CI_meds, x=lp_sess_ns, 
                width=0.45, color="lightgrey", med_col="gray", med_rat=0.03, 
                zorder=-12)

            # within-session significance markers, placed above the errorbars
            y_maxes = data[:, 0] + data[:, -1]
            p_vals_corr = [lp_df.loc[idx, "p_vals"] for idx in row_idxs]
            lp_sig_str = f"{line_plane_name:6} (within session):"
            for sess_n, sess_data, CI_med, y_max, p_val in zip(
                lp_sess_ns, data, CI_meds, y_maxes, p_vals_corr
                ):
                # side of the null median the data falls on
                side = np.sign(sess_data[0] - CI_med)
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side, 
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if sig_str:
                    plot_util.add_signif_mark(sub_ax, sess_n, y_max, 
                        rel_y=0.15, color=col, mark=sig_str)

                lp_sig_str = (
                    f"{lp_sig_str}{TAB} S{sess_n}: {p_val:.5f}{sig_str:3}"
                    )

            logger.info(lp_sig_str, extra={"spacing": TAB})

    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    if decoder_data:
        area = False
        if "balanced" in data_col:
            ylab = "Balanced accuracy (%)"
        else:
            ylab = "Accuracy %"
    else:
        area, ylab = True, None

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=area, ylab=ylab, datatype="roi", sess_ns=sess_ns, kind="reg", 
        xticks=sess_ns, modif_share=False)

    return ax
def add_between_stim_sig(ax, sub_ax_all, data_df, permpar):
    """
    add_between_stim_sig(ax, sub_ax_all, data_df, permpar)

    Plot significance markers for significant comparisons between stimulus 
    types (Gabors vs visual flow), and log the p-value for each line/plane.

    Required args:
        - ax (2D array): 
            array of subplots, indexed as ax[plane index, line index]
        - sub_ax_all (plt subplot): 
            subplot used for rows where the line or plane is "all"
        - data_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - gabors (list): Gabor data stats (me, err)
            - visflow (list): visual flow data stats (me, err)
            - p_vals (float): p-value for the Gabors vs visual flow 
                comparison

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    gab_col, vf_col = "gabors", "visflow"
    stim_sig_str = "Gabors vs visual flow: "
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):

        if len(lp_df) != 1:
            raise RuntimeError("Expected 1 row per line/plane/session.") 
        row_idx = lp_df.index[0]

        # stats x stimulus: row 0 holds the (gabors, visflow) means, the last 
        # row their errors
        stim_stats = np.vstack(
            [lp_df[gab_col].tolist(), lp_df[vf_col].tolist()]
            ).T
        means = stim_stats[0]
        errs = stim_stats[1:]
        highest = np.max(means + errs[-1])

        pooled = line == "all" or plane == "all"
        if pooled:
            col = plot_helper_fcts.NEARBLACK
            linpla_name = "All"
            sub_ax = sub_ax_all
            # place the marker above the highest mean across all rows
            all_means = np.concatenate(
                [data_df[gab_col].tolist(), data_df[vf_col].tolist()], 
                axis=0
                )[:, 0]
            highest = np.max([means.max(), all_means.max()])
            mark_rel_y = 0.15
        else:
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]
            mark_rel_y = 0.18

        p_val = lp_df.loc[row_idx, "p_vals"]
        sig_str = misc_analys.get_sig_symbol(
            p_val, sensitivity=sensitivity, side=np.sign(means[1] - means[0]), 
            tails=permpar["tails"], p_thresh=permpar["p_val"]
            )
        stim_sig_str = \
            f"{stim_sig_str}{TAB}{linpla_name}: {p_val:.5f}{sig_str:3}"

        if sig_str:
            plot_util.plot_barplot_signif(
                sub_ax, [0, 1], highest, rel_y=0.11, color=col, lw=3, 
                mark_rel_y=mark_rel_y, mark=sig_str, 
            )

    logger.info(stim_sig_str, extra={"spacing": TAB})
# Example #4
def plot_idx_corr_scatterplots(idx_corr_df,
                               permpar,
                               figpar,
                               permute="sess",
                               title=None):
    """
    plot_idx_corr_scatterplots(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlation scatterplots for individual session 
    comparisons, over contours of the binned random correlation data, and 
    logs the correlation p-value for each line/plane.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for correlation data (normalized if corr_type is "diff_corr") for 
            session comparisons (x, y), e.g. 1v2
            - binned_rand_stats (list): number of random correlation values per 
                bin (xs x ys)
            - corr_data_xs (list): USI values for x
            - corr_data_ys (list): USI values for y
            - corrs (float): correlation between session data (x and y)
            - p_vals (float): p-value for correlation, corrected for 
                multiple comparisons and tails
            - rand_corr_meds (float): median of the random correlations
            - raw_p_vals (float): p-value for intersession correlations
            - regr_coefs (float): regression correlation coefficient (slope)
            - regr_intercepts (float): regression correlation intercept
            - x_bin_mids (list): x mid point for each random correlation bin
            - y_bin_mids (list): y mid point for each random correlation bin

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots
    """

    # for "sess" and "all" permutations, y values are session differences
    diffs = permute in ["sess", "all"]

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 4
    figpar["init"]["subplot_wid"] = 4
    figpar["init"]["gs"] = {"hspace": 0.4, "wspace": 0.4}

    fig, ax = plot_util.init_fig(4, **figpar["init"])

    if title is not None:
        fig.suptitle(title, fontweight="bold", y=0.97)

    # axis labels are set from the session numbers of the first line/plane 
    # row; initialized here so that an empty dataframe does not leave them 
    # unbound
    sess_ns = None
    xlabel, ylabel = None, None

    # first pass to plot
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(lp_df) != 1:
            raise RuntimeError("Expected exactly one row.")
        lp_row = lp_df.loc[lp_df.index[0]]

        if sess_ns is None:
            sess_ns = lp_row["sess_ns"]
            xlabel = f"Session {sess_ns[0]} USIs"
            ylabel = f"Session {sess_ns[1]} USIs"
            if diffs:
                ylabel = f"Session {sess_ns[1]} - {sess_ns[0]} USIs"

        elif sess_ns != lp_row["sess_ns"]:
            raise RuntimeError("Expected all sess_ns to match.")

        # contours of the binned random correlation data
        density_data = [
            lp_row["x_bin_mids"], lp_row["y_bin_mids"],
            np.asarray(lp_row["binned_rand_stats"]).T
        ]
        sub_ax.contour(*density_data,
                       levels=6,
                       cmap="Greys",
                       zorder=-13,
                       linewidths=4)

        # the more points there are, the more transparent each one is
        alpha = 0.3**(len(lp_row["corr_data_xs"]) / 300)
        sub_ax.scatter(lp_row["corr_data_xs"],
                       lp_row["corr_data_ys"],
                       color=col,
                       alpha=alpha,
                       lw=2,
                       s=35)

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax,
                                         datatype="roi",
                                         xticks=None,
                                         ylab=ylabel,
                                         xlab=xlabel,
                                         kind="reg")

    # second pass to add plot markings
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    sig_str = ""
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
        sub_ax = ax[pl, li]

        # add markers back in (removed due to kind='reg')
        sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)
        sub_ax.spines["bottom"].set_visible(True)

        lp_row = lp_df.loc[lp_df.index[0]]

        p_val_corr = lp_row["p_vals"]
        lp_sig_str = add_scatterplot_markers(sub_ax,
                                             permpar,
                                             lp_row["corrs"],
                                             lp_row["rand_corr_meds"],
                                             lp_row["regr_coefs"],
                                             lp_row["regr_intercepts"],
                                             p_val_corr,
                                             col=col,
                                             diffs=diffs)

        sig_str = (f"{sig_str}{TAB}{line_plane_name}: "
                   f"{p_val_corr:.5f}{lp_sig_str:3}")

    logger.info(sig_str, extra={"spacing": TAB})

    return ax
# Example #5
def add_between_sess_sig(ax, data_df, permpar, data_col="diff_stats", 
                         highest=None, ctrl=False, p_val_prefix=False, 
                         dry_run=False):
    """
    add_between_sess_sig(ax, data_df, permpar)

    Plot significance markers for significant comparisons between sessions, 
    and log the between-session p-values for each line/plane.

    Runs in two passes over the line/plane groups: the first logs p-values and 
    raises each subplot's upper y limit to make room for the markers, the 
    second draws the markers, stacked above the highest data point.

    Required args:
        - ax (2D array): 
            array of subplots, indexed as ax[plane index, line index]
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_col} (list): data stats (me, err)
            for session comparisons, e.g. 1v2:
            - p_vals_{}v{} (float): p-value for differences between sessions, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple

    Optional args:
        - data_col (str):
            data column name in data_df
            default: "diff_stats"
        - highest (list):
            highest point for each line/plane, in order (updated in place)
            default: None
        - ctrl (bool): 
            if True, significance symbols should use control colour and symbol
            default: False
        - p_val_prefix (bool):
            if True, p-value columns start with data_col as a prefix 
            "{data_col}_p_vals_{}v{}".
            default: False
        - dry_run (bool):
            if True, a dry-run is done to get highest values, but nothing is 
            plotted or logged.
            default: False

    Returns:
    - highest (list):
        highest point for each line/plane, in order, after plotting
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    prefix = f"{data_col}_" if p_val_prefix else ""

    if not dry_run:
        logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    # marker styling
    if ctrl:
        mark_rel_y, fontsize = 0.22, 14
    else:
        mark_rel_y, fontsize = 0.18, 20

    for pass_n in (0, 1):
        drawing_pass = pass_n == 1  # markers are only drawn on the 2nd pass

        linpla_grps = list(data_df.groupby(["lines", "planes"]))
        if highest is None:
            highest = [0] * len(linpla_grps)
        elif len(highest) != len(linpla_grps):
            raise ValueError("If highest is provided, it must contain as "
                "many values as line/plane groups in data_df.")

        for grp_i, ((line, plane), lp_df) in enumerate(linpla_grps):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            if ctrl:
                col = "gray"

            # retrieve the single row for each session, in session order
            lp_sess_ns = np.sort(lp_df["sess_ns"].unique())
            sess_rows = []
            for sess_n in lp_sess_ns:
                rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(rows) != 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")
                sess_rows.append(rows.loc[rows.index[0]])
            sess_info = list(zip(lp_sess_ns, sess_rows))

            sig_comps = []  # (sig_str, [sess_n1, sess_n2]) per significant pair
            lp_sig_str = f"{line_plane_name:6} (between sessions):"
            for i, (sess_n1, row1) in enumerate(sess_info):
                for sess_n2, row2 in sess_info[i + 1:]:
                    stats1, stats2 = row1[data_col], row2[data_col]

                    # track the highest point (me + err) reached by the data
                    highest[grp_i] = np.nanmax([
                        highest[grp_i], 
                        stats1[0] + stats1[-1], 
                        stats2[0] + stats2[-1],
                        ])

                    if dry_run:
                        continue

                    p_val = row1[
                        f"{prefix}p_vals_{int(sess_n1)}v{int(sess_n2)}"
                        ]
                    sig_str = misc_analys.get_sig_symbol(
                        p_val, sensitivity=sensitivity, 
                        side=np.sign(stats2[0] - stats1[0]), 
                        tails=permpar["tails"], p_thresh=permpar["p_val"], 
                        ctrl=ctrl
                        )

                    if sig_str:
                        sig_comps.append((sig_str, [sess_n1, sess_n2]))

                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n1}v{sess_n2}: "
                        f"{p_val:.5f}{sig_str:3}"
                        )

            if dry_run:
                continue

            n_sig = len(sig_comps)
            ylims = sub_ax.get_ylim()
            step = np.diff(ylims)[0] / 8.0  # vertical spacing between markers

            if not drawing_pass:
                logger.info(lp_sig_str, extra={"spacing": TAB})
                if n_sig:
                    # raise the upper y limit to fit all the markers
                    ylims = [
                        ylims[0], 
                        np.nanmax([ylims[1], highest[grp_i] + step * (n_sig + 1)])
                        ]
                    sub_ax.set_ylim(ylims)
                continue

            if not n_sig:
                continue

            # add significance markers, stacked one step apart
            y_pos = highest[grp_i]
            for level, (sig_str, sig_x) in enumerate(sig_comps, start=1):
                y_pos = highest[grp_i] + level * step
                plot_util.plot_barplot_signif(
                    sub_ax, sig_x, y_pos, rel_y=0.11, color=col, lw=3, 
                    mark_rel_y=mark_rel_y, mark=sig_str, fontsize=fontsize
                    )
            highest[grp_i] = np.nanmax([highest[grp_i], y_pos])

            if y_pos > ylims[1]:
                sub_ax.set_ylim([ylims[0], y_pos * 1.1])

    return highest
# Example #6
def plot_idx_correlations(idx_corr_df,
                          permpar,
                          figpar,
                          permute="sess",
                          corr_type="corr",
                          title=None,
                          small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions, with one subplot row 
    per plane and one subplot column per session pair (padded to a multiple 
    of 2). Each subplot shows the adjusted null CI, the observed correlation 
    with its bootstrapped standard deviation, and a significance marker, if 
    applicable. p-values are also logged.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI 
                index correlations (low, median, high)
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all"). If 
            "sess" or "all" and corr_type is "corr", normalized correlations 
            (the "_norm" columns) are plotted.
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title. If normalized correlations are plotted, 
            "Correlations" is replaced with "Normalized correlations" in it.
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array): 
            array of subplots
    """

    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # the title is optional (default None): only relabel it if given
            if title is not None:
                title = title.replace(
                    "Correlations", "Normalized correlations"
                    )

    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2)  # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(figpar,
                                            kind="reg",
                                            n_sub=int(n_pairs / 2))

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"

    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    # 2 rows (one per plane) x n_pairs columns (one per session pair)
    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    # depends only on permpar, so it is computed once for all comparisons
    sensitivity = misc_analys.get_sensitivity(permpar)

    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI is stored as (low, median, high)
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(sub_ax,
                              extr,
                              med=CI[1],
                              x=li,
                              width=0.45,
                              med_rat=0.025)

            y = row[f"{col_base}{norm_str}_corrs"]

            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(sub_ax,
                               x=li,
                               y=y,
                               err=err,
                               color=col,
                               capsize=8)

            # add significance markers, on the side of the null median
            # that the observed correlation falls on
            p_val = row[f"{col_base}_p_vals"]
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(p_val,
                                                 sensitivity=sensitivity,
                                                 side=side,
                                                 tails=permpar["tails"],
                                                 p_thresh=permpar["p_val"])

            if len(sig_str):
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(sub_ax,
                                          li,
                                          high,
                                          rel_y=0.1,
                                          color=col,
                                          fontsize=24,
                                          mark=sig_str)

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax,
                                         datatype="roi",
                                         lines=["", ""],
                                         xticks=[0, 1],
                                         ylab="",
                                         xlab="",
                                         kind="traces")

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            if row_n == 0:
                # padding columns (beyond the real session pairs) get no title
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    sub_ax.set_title(sess_pair_title,
                                     fontweight="bold",
                                     y=1.07)
                sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(ax[row_n],
                                   3,
                                   axis="y",
                                   weight="bold",
                                   share=False,
                                   update_ticks=True)

    return ax
示例#7
0
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False, 
                       title=None):
    """
    plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar)

    Plots percentage of significant USIs, with one subplot per tail (low, 
    high), and one x position per line/plane. The chance level per tail is 
    permpar["p_val"] / 2 * 100 percent. p-values are also logged.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns:
            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation 
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig. ROIs 
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for percent 
                sig. ROIs (low, median, high)
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for 
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig. 
                ROIs, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool):
            if True, plotting is done per mouse (multiple rows per line/plane 
            allowed)
            default: False
        - title (str):
            plot title
            default: None
        
    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if perc_sig_df contains more than one session 
            number
        - RuntimeError: if by_mouse is False and a line/plane has more than 
            one row
    """  

    perc_sig_df = perc_sig_df.copy(deep=True)

    nanpol = None if analyspar["rem_bad"] else "omit"

    sess_ns = perc_sig_df["sess_ns"].unique()
    if len(sess_ns) != 1:
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."
            )

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", n_sub=1, 
        sharex=True, sharey=True)

    figpar["init"]["sharey"] = True
    figpar["init"]["subplot_wid"] = 3.4
    figpar["init"]["gs"] = {"wspace": 0.18}
    if by_mouse:
        figpar["init"]["subplot_hei"] = 8.4
    else:
        figpar["init"]["subplot_hei"] = 3.5

    fig, ax = plot_util.init_fig(2, **figpar["init"])
    if title is not None:
        y = 0.98 if by_mouse else 1.07
        fig.suptitle(title, y=y, weight="bold")

    tail_order = ["Low tail", "High tail"]
    tail_keys = ["lo", "hi"]
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    comp_info = misc_analys.get_comp_info(permpar)
    
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for t, (tail, key) in enumerate(zip(tail_order, tail_keys)):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")
        sub_ax.set_ylim(ylims)

        # replace bottom spine with line at 0
        sub_ax.spines['bottom'].set_visible(False)
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        CIs = np.full((n_linpla, 2), np.nan)
        CI_meds = np.full(n_linpla, np.nan)

        tail_sig_str = f"{tail:9}:"
        # tick labels are stored by x position, so that they stay aligned with 
        # the plotted data, regardless of groupby order or of any missing 
        # line/plane (which keeps an empty label)
        linpla_names = [""] * n_linpla
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            linpla_names[x_index] = linpla_name
            
            # groupby can yield empty groups (e.g., unobserved categories)
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns") # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col, 
                    interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                    )
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                # zero-width CI, used only to draw the mean/median line
                CI_dummy = np.repeat(mouse_data_mean, 2)
                plot_util.plot_CI(sub_ax, CI_dummy, med=mouse_data_mean, 
                    x=x_index, width=0.6, med_col=col, med_rat=0.01)
            else:
                # collect confidence interval data (low, median, high)
                row = lp_df.loc[df_indices[0]]
                mouse_cols = [col]
                CIs[x_index] = np.asarray(row[f"{data_key}_null_CIs"])[
                    np.asarray([0, 2])
                    ]
                CI_meds[x_index] = row[f"{data_key}_null_CIs"][1]

            if by_mouse:
                perc_p_vals = []
                rel_y = 0.05
            else:
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                # plot UFOs
                err = None
                no_line = True
                if not by_mouse:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False
                # indicate bootstrapped error with wider capsize
                plot_util.plot_ufo(
                    sub_ax, x_index, perc_sig_df.loc[df_i, data_key], err,
                    color=mouse_col, capsize=8, no_line=no_line
                    )

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                perc = perc_sig_df.loc[df_i, data_key]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                side = np.sign(perc - chance)
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                    )                
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side, 
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if len(sig_str):
                    perc_high = perc + err if err is not None else perc
                    plot_util.add_signif_mark(sub_ax, x_index, perc_high, 
                        rel_y=rel_y, color=mouse_col, fontsize=24, 
                        mark=sig_str) 

                if by_mouse:
                    perc_p_vals.append(
                        (int(np.around(perc)), p_val, sig_str)
                    )
                else:
                    tail_sig_str = (
                        f"{tail_sig_str}{p_val:.5f}{sig_str:3}"
                        )

            if by_mouse: # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
                        f"{p_val:.5f}{sig_str:3}"
                        )
                
        # add chance information
        if by_mouse:
            sub_ax.axhline(
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5, 
                zorder=-12
                )
        else:
            plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, 
                x=np.arange(n_linpla), width=0.45, med_rat=0.025, zorder=-12)

        logger.info(tail_sig_str, extra={"spacing": TAB})
    
    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False) 
        plot_util.set_ticks(
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2)
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    ax[0, 0].set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if ax[0, 0].get_ylim()[0] < 0:
        ax[0, 0].set_ylim([ylims[0], ax[0, 0].get_ylim()[1]])

    return ax
示例#8
0
def plot_roi_tracking(roi_mask_df, figpar, title=None):
    """
    plot_roi_tracking(roi_mask_df, figpar)
    
    Plots an ROI tracking example: registered ROI masks, colored by session, 
    for the session orderings yielding the fewest and the most tracked ROIs, 
    and for the union across orderings (minus conflicts). A vertical scale 
    marker and a session color legend are added, and match counts are logged.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with exactly one row (a single example), and the 
            following columns, in addition to the basic sess_df columns: 
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)
            - "union_n_conflicts" (int): number of conflicts after union
            for "union", "fewest" and "most" tracked ROIs:
            - "{}_registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val),
                ordered by {}_sess_ns if "fewest" or "most"
            - "{}_n_tracked" (int): number of tracked ROIs
            for "fewest", "most" tracked ROIs:
            - "{}_sess_ns" (list): ordered session number 
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - ValueError: if roi_mask_df does not contain exactly one row
    """

    if len(roi_mask_df) != 1:
        raise ValueError("Expected only one row in roi_mask_df")
    row = roi_mask_df.loc[roi_mask_df.index[0]]

    # the empty key marks a blank panel holding the union annotation + legend
    panel_keys = ["fewest", "most", "", "union"]
    blank_idx = panel_keys.index("")

    figpar["init"].update({
        "ncols": len(panel_keys),
        "sharex": False,
        "sharey": False,
        "subplot_hei": 5.05,
        "subplot_wid": 5.05,
        "gs": {"wspace": 0.06},
    })

    # scale marker axis rect, in figure coordinates: 
    # [left, bottom, width, height]. MUST ADJUST if the layout above changes.
    scale_ax_rect = [0.905, 0.125, 0.06, 0.74] 

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    scale_ax = fig.add_axes(scale_ax_rect)
    plot_util.remove_axis_marks(scale_ax)
    scale_ax.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=1.05, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6

    # label the first panel with the line/plane name
    lp_col = plot_helper_fcts.get_line_plane_idxs(
        row["lines"], row["planes"]
        )[2]
    lp_name = plot_helper_fcts.get_line_plane_name(
        row["lines"], row["planes"]
        )
    ax[0, 0].set_ylabel(lp_name, fontweight="bold", color=lp_col)
    log_lines = [f"Conflicts and matches for a {lp_name} example:"]

    for c, key in enumerate(panel_keys):
        sub_ax = ax[0, c]

        if not key:
            sub_ax.set_axis_off()
            sub_ax.set_title(
                "     Union - conflicts\n...   ====================>", 
                fontweight="bold", 
                y=0.5
                )
            continue

        plot_util.remove_axis_marks(sub_ax)
        for spine in ("right", "left", "top", "bottom"):
            sub_ax.spines[spine].set_visible(True)

        if key == "union":
            title_y = 1.04
            sess_order = row["sess_ns"]
            n_union = int(row[f"{key}_n_tracked"])
            n_conflicts = int(row[f"{key}_n_conflicts"])
            n_matches = n_union - n_conflicts
            panel_title = f"{n_matches} matches"
            log_lines.append(
                "Union - conflicts: "
                f"{n_union} - {n_conflicts} = {n_matches} matches"
                )
        else:
            title_y = 1.01
            sess_order = row[f"{key}_sess_ns"]
            sess_order_str = ", ".join(str(n) for n in sess_order)
            n_matches = int(row[f"{key}_n_tracked"])
            panel_title = f"{n_matches} matches\n(sess {sess_order_str})"
            log_lines.append(
                f"{key.capitalize()} matches (sess {sess_order_str}): "
                f"{n_matches}"
                )

        sub_ax.set_title(panel_title, fontweight="bold", y=title_y)

        roi_masks = create_sess_roi_masks(
            row, 
            mask_key=f"{key}_registered_roi_mask_idxs"
            )

        # masks are stored in sess_order, but are drawn in sess_ns order
        for sess_n in row["sess_ns"]:
            mask_idx = sess_order.index(sess_n)
            add_roi_mask(
                sub_ax, roi_masks[mask_idx], col=sess_cols[int(sess_n)], 
                alpha=alpha
                )

    # add scale marker (height of the ROI mask frame)
    add_scale_marker(
        scale_ax, side_len=row["roi_mask_shapes"][1], ori="vertical", 
        quadrant=3, fontsize=20
        )

    logger.info(f"\n{TAB}".join(log_lines), extra={"spacing": "\n"})

    # add legend
    add_sess_col_leg(
        ax[0, blank_idx], 
        sess_cols, 
        bbox_to_anchor=(0.67, 0.3), 
        alpha=alpha
        )

    return ax
示例#9
0
def plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar, title=None):
    """
    plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar)

    Plots ROI mask contours overlayed over imaging planes, and ROI masks 
    overlayed over each other across sessions.

    The figure has 2 * n_planes rows and n_sess * n_lines columns. Each 
    line/plane is assigned a 2-row group of subplots, which is filled by 
    add_proj_and_roi_masks(). A line/plane label, a horizontal scale marker 
    and a session color legend are then added, and the even rows are set to 
    rasterize the imaging plane images.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with exactly one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 

            - "max_projections" (list): pixel intensities of maximum projection 
                for the plane (hei x wid)
            - "registered_roi_mask_idxs" (list): list of mask indices, 
                registered across sessions, for each session 
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_idxs" (list): list of mask indices for each session, 
                and each ROI (sess x (ROI, hei, wid) x val) (not registered)
            - "roi_mask_shapes" (list): shape into which ROI mask indices index 
                (sess x hei x wid)

            optionally (cropping is requested from add_proj_and_roi_masks() 
            only if the "crop_fact" column is present):
            - "crop_fact" (num): factor by which to crop masks (> 1) 
            - "shift_prop_hei" (float): proportion by which to shift cropped 
                mask center vertically [0, 1] 
                (NOTE(review): previously documented as "from left edge", 
                which cannot apply to a vertical shift; confirm the 
                reference edge in the cropping helper)
            - "shift_prop_wid" (float): proportion by which to shift cropped 
                mask center horizontally from left edge [0, 1]

        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters  

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a line/plane does not have exactly one row
        - NotImplementedError: if there are fewer than 2 sessions (scale bar 
            and legend placement are not implemented for that case)
    """

    n_lines = len(roi_mask_df["lines"].unique())
    n_planes = len(roi_mask_df["planes"].unique())

    sess_cols = get_sess_cols(roi_mask_df)
    n_sess = len(sess_cols)
    n_cols = n_sess * n_lines

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 2.3
    figpar["init"]["subplot_wid"] = 2.3
    figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.2}
    figpar["init"]["ncols"] = n_cols

    # n_cols * n_planes * 2 subplots over n_cols columns -> 2 rows per plane
    fig, ax = plot_util.init_fig(n_cols * n_planes * 2, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.93, weight="bold")

    # cropping is only requested if cropping parameters are provided
    crop = "crop_fact" in roi_mask_df.columns

    # NOTE(review): same call as above; the result is recomputed here
    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6
    # artists with a zorder below this value are rasterized (see the 
    # set_rasterization_zorder() call below); projections are drawn below it
    raster_zorder = -12

    for (line, plane), lp_mask_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, _, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        # NOTE(review): repeats the call above, only to retrieve the color
        lp_col = plot_helper_fcts.get_line_plane_idxs(line, plane)[2]
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]

        # identify subplots
        # NOTE(review): (pl % n_planes) * n_planes equals pl * 2 (2 rows per 
        # plane) only when n_planes == 2, and (li % n_lines) * n_lines equals 
        # li * n_sess only when n_lines == n_sess. The column slice also spans 
        # n_sess + 1 columns, so groups for different lines may overlap. 
        # Confirm against the layout expected by add_proj_and_roi_masks().
        base_row = (pl % n_planes) * n_planes
        base_col = (li % n_lines) * n_lines

        ax_grp = ax[base_row : base_row + 2, base_col : base_col + n_sess + 1]

        # add imaging planes and masks
        imaging_planes = add_proj_and_roi_masks(
            ax_grp, lp_row, sess_cols, crop=crop, alpha=alpha, 
            proj_zorder=raster_zorder - 1
            )

        # add markings, on the second row of the group
        shared_row = base_row + 1
        shared_col = base_col + int((n_sess - 1) // 2)
        shared_sub_ax = ax[shared_row, shared_col]

        if shared_col == 0:
            shared_sub_ax.set_ylabel(lp_name, fontweight="bold", color=lp_col)
        else:
            # otherwise, the line/plane name is written as text in the first 
            # column of the shared row
            lp_sub_ax = ax[shared_row, 0]
            lp_sub_ax.set_xlim([0, 1])
            lp_sub_ax.set_ylim([0, 1])
            lp_sub_ax.text(
                0.5, 0.5, lp_name, fontweight="bold", color=lp_col, 
                ha="center", va="center", fontsize="x-large"
                )

        # add scale bar, in the last column of the shared row, spanning the 
        # width (last dimension) of the first returned imaging plane
        # NOTE(review): this check runs after the subplots are already drawn
        if n_sess < 2:
            raise NotImplementedError(
                "Scale bar placement not implemented for fewer than 2 "
                "sessions."
                )
        scale_ax = ax[shared_row, -1]
        wid_len = imaging_planes[0].shape[-1]
        add_scale_marker(
            scale_ax, side_len=wid_len, ori="horizontal", quadrant=1, 
            fontsize=20
            )

    logger.info("Rasterizing imaging plane images...", extra={"spacing": TAB})
    for i in range(ax.shape[0]):
        for j in range(ax.shape[1]):
            sub_ax = ax[i, j]
            plot_util.remove_axis_marks(sub_ax)
            # only even rows (first row of each group) are rasterized
            if not(i % 2):
                sub_ax.set_rasterization_zorder(raster_zorder)

    # add legend, in the bottom right subplot
    if n_sess < 2:
        raise NotImplementedError(
            "Legend placement not implemented for fewer than 2 sessions."
            )
    add_sess_col_leg(
        ax[-1, -1], sess_cols, bbox_to_anchor=(1, 0.6), alpha=alpha, 
        fontsize="small"
        )

    return ax
示例#10
0
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar, 
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences: one subplot per datatype 
    (running, pupil), with the per-block differences of each line/plane 
    plotted via plot_violin_data(), and the p-value for each line/plane 
    annotated above the data. Significance information is also logged.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane (for a single 
            session number), and the following columns, in addition to the 
            basic sess_df columns: 
            - run_block_diffs (list): 
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list): 
                pupil diameter differences per block
            - pupil_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            (currently not used within this function)
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int): 
            seed value to use. (-1 treated as None)
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if analyspar["scale"] is set, or if block_df 
            contains more than one session number
        - RuntimeError: if block_df contains more than one row for a 
            line/plane
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
            )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    # NaNs are only omitted from statistics if bad values were not removed
    nanpol = None if analyspar["rem_bad"] else "omit"
    
    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    # subplot titles (with units), per datatype
    datatype_titles = {
        "run": "Running velocity (cm/s)",
        "pupil": "Pupil diameter (mm)",
    }
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False, 
        gridspec_kw={"wspace": 0.22}
        )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, datatype in enumerate(datatypes):
        datatype_sig_str = f"{datatype_strs[d]:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
                )
            lp_names[int(x)] = line_plane_name

            # groupby can yield empty groups (e.g., with categorical columns): 
            # skip them, rather than using an unbound or stale row index
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]
        
            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]
            
            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            # direction of the effect, used for the significance symbol
            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side, 
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
                )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
            )
        
        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False) 

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False, 
            fontweight="bold", update_ticks=True
            )

        # p-value texts are placed just above the final y limits (computed 
        # once, as a scalar, since adding text does not change the limits)
        ylim_min, ylim_max = sub_ax.get_ylim()
        y = ylim_max + (ylim_max - ylim_min) * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                # the first annotation also states the p-value type
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, y, p_val_text, fontsize=20, weight="bold", ha=ha
                )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles (after p-value placement, as axhline may affect 
    # the y limits)
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
            )
     
        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")    

        sub_ax.set_title(datatype_titles[datatype], weight="bold", y=1.2)
        
    return ax