예제 #1
0
def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar, 
                   between_sess_sig=True, data_col="diff_stats", 
                   decoder_data=False, title=None, wide=False):
    """
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots errorbar data across sessions, one subplot per line/plane, with 
    null CIs and within-session significance markers (and optionally 
    between-session significance markers).

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_col} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data 
                (low, median, high)
            - raw_p_vals (float): uncorrected p-value for data within 
                sessions
            - p_vals (float): p-value for data within sessions, 
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data 
                differences between sessions 
            - p_vals_{}v{} (float): p-value for data between sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True         
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data (y axis shared across all 
            subplots, accuracy y label)
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False
        
    Returns:
        - ax (2D array): 
            array of subplots
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)
    
    figpar["init"]["sharey"] = True if decoder_data else "row"
    figpar["init"]["subplot_hei"] = 4.4
    # horizontal spacing is the same for regular and wide subplots
    figpar["init"]["gs"] = {"hspace": 0.2, "wspace": 0.3}
    figpar["init"]["subplot_wid"] = 3.0 if wide else 2.6

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])

    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    for pass_n in [0, 1]: # add significance markers on the second pass
        if pass_n == 1:
            logger.info(f"{comp_info}:", extra={"spacing": "\n"})
            
        for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            # collect the row index for each session present, in sess_ns order
            sess_indices = []
            lp_sess_ns = []
            for sess_n in sess_ns:
                rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(rows) == 1:
                    sess_indices.append(rows.index[0])
                    lp_sess_ns.append(sess_n)
                elif len(rows) > 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")

            # no session data for this line/plane: nothing to plot (an empty 
            # 1D array could not be indexed as data[:, 0] below)
            if not sess_indices:
                continue

            # one row per session: (me, err, ...)
            data = np.asarray([lp_df.loc[i, data_col] for i in sess_indices])

            if pass_n == 0:
                # plot errorbars
                plot_util.plot_errorbars(
                    sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=col, 
                    alpha=0.8, xticks="auto", line_dash=dash
                    )

            if pass_n == 1:
                # plot CIs (null_CIs stored as low, median, high)
                CIs = np.asarray(
                    [lp_df.loc[i, "null_CIs"] for i in sess_indices]
                    )
                CI_meds = CIs[:, 1]
                CIs = CIs[:, np.asarray([0, 2])]

                plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, x=lp_sess_ns, 
                    width=0.45, color="lightgrey", med_col="gray", med_rat=0.03, 
                    zorder=-12)

                # add significance markers within sessions, above errorbars
                y_maxes = data[:, 0] + data[:, -1]
                # direction of the effect relative to the null median
                sides = [
                    np.sign(sub[0] - CI_med) 
                    for sub, CI_med in zip(data, CI_meds)
                    ]
                p_vals_corr = [lp_df.loc[i, "p_vals"] for i in sess_indices]
                lp_sig_str = f"{line_plane_name:6} (within session):"
                for s, sess_n in enumerate(lp_sess_ns):
                    sig_str = misc_analys.get_sig_symbol(
                        p_vals_corr[s], sensitivity=sensitivity, side=sides[s], 
                        tails=permpar["tails"], p_thresh=permpar["p_val"]
                        )

                    if len(sig_str):
                        plot_util.add_signif_mark(sub_ax, sess_n, y_maxes[s], 
                            rel_y=0.15, color=col, mark=sig_str)  

                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n}: "
                        f"{p_vals_corr[s]:.5f}{sig_str:3}"
                        )

                logger.info(lp_sig_str, extra={"spacing": TAB})
        
    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    area, ylab = True, None
    if decoder_data:
        area = False
        if "balanced" in data_col:
            ylab = "Balanced accuracy (%)" 
        else:
            ylab = "Accuracy %"

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=area, ylab=ylab, datatype="roi", sess_ns=sess_ns, kind="reg", 
        xticks=sess_ns, modif_share=False)

    return ax
예제 #2
0
def add_between_stim_sig(ax, sub_ax_all, data_df, permpar):
    """
    add_between_stim_sig(ax, sub_ax_all, data_df, permpar)

    Plot significance markers for significant comparisons between stimulus 
    types (Gabors vs visual flow), and log the p-values.

    Required args:
        - ax (2D array): 
            array of subplots, indexed as ax[plane_idx, line_idx]
        - sub_ax_all (plt subplot): 
            subplot used for rows where line or plane is "all"
        - data_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - gabors (list): Gabor data stats (me, err)
            - visflow (list): visual flow data stats (me, err)
            - p_vals (float): p-value for the difference between stimulus 
                types, corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    stimtypes = ["gabors", "visflow"]
    stim_sig_str = "Gabors vs visual flow: "
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):

        if len(lp_df) != 1:
            raise RuntimeError("Expected 1 row per line/plane/session.") 
        row_idx = lp_df.index[0]   

        x = [0, 1]
        # after transposing: rows are stats (me first, then error values), 
        # columns are stimulus types
        data = np.vstack(
            [lp_df[stimtypes[0]].tolist(), lp_df[stimtypes[1]].tolist()]
            ).T
        y = data[0]
        err = data[1:]
        # top of the highest errorbar for this line/plane
        highest = np.max(y + err[-1])

        if line != "all" and plane != "all":
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]
            mark_rel_y = 0.18
        else:
            col = plot_helper_fcts.NEARBLACK
            linpla_name = "All"
            sub_ax = sub_ax_all
            # highest mean across every line/plane in data_df
            all_data_max = np.concatenate(
                [data_df[stimtypes[0]].tolist(), 
                data_df[stimtypes[1]].tolist()], 
                axis=0
                )[:, 0].max()
            # the bar must clear both this row's errorbars and all the means
            highest = np.max([highest, all_data_max])
            mark_rel_y = 0.15

        p_val = lp_df.loc[row_idx, "p_vals"]
        side = np.sign(y[1] - y[0])

        sig_str = misc_analys.get_sig_symbol(
            p_val, sensitivity=sensitivity, side=side, tails=permpar["tails"], 
            p_thresh=permpar["p_val"]
            )
        stim_sig_str = \
            f"{stim_sig_str}{TAB}{linpla_name}: {p_val:.5f}{sig_str:3}"
        
        if len(sig_str):
            plot_util.plot_barplot_signif(
                sub_ax, x, highest, rel_y=0.11, color=col, lw=3, 
                mark_rel_y=mark_rel_y, mark=sig_str, 
            )
    
    logger.info(stim_sig_str, extra={"spacing": TAB})
예제 #3
0
def add_between_sess_sig(ax, data_df, permpar, data_col="diff_stats", 
                         highest=None, ctrl=False, p_val_prefix=False, 
                         dry_run=False):
    """
    add_between_sess_sig(ax, data_df, permpar)

    Plot significance markers for significant comparisons between sessions.

    Works in two passes over the line/plane groups: on the first pass, the 
    p-values are logged and each subplot's upper y limit is raised to make 
    room for the significance bars; on the second pass, the bars are drawn, 
    stacked above the highest data point.

    Required args:
        - ax (2D array): 
            array of subplots, indexed as ax[plane_idx, line_idx]
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - {data_col} (list): data stats (me, err)
            for session comparisons, e.g. 1v2:
            - p_vals_{}v{} (float): p-value for differences between sessions, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple

    Optional args:
        - data_col (str):
            data column name in data_df
            default: "diff_stats"
        - highest (list):
            highest point for each line/plane, in data_df.groupby order. 
            NOTE: if provided, this list is updated in place.
            default: None
        - ctrl (bool): 
            if True, significance symbols should use control colour and symbol
            default: False
        - p_val_prefix (bool):
            if True, p-value columns start with data_col as a prefix 
            "{data_col}_p_vals_{}v{}".
            default: False
        - dry_run (bool):
            if True, a dry-run is done to get highest values, but nothing is 
            plotted or logged.
            default: False

    Returns:
        - highest (list):
            highest point for each line/plane, in order, after plotting 
            (includes the position of the last significance bar drawn)
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    prefix = f"{data_col}_" if p_val_prefix else ""

    if not dry_run:
        logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    for pass_n in [0, 1]: # add significance markers on the second pass
        linpla_grps = list(data_df.groupby(["lines", "planes"]))
        if highest is None:
            highest = [0] * len(linpla_grps)
        elif len(highest) != len(linpla_grps):
            raise ValueError("If highest is provided, it must contain as "
                "many values as line/plane groups in data_df.")

        for l, ((line, plane), lp_df) in enumerate(linpla_grps):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            if ctrl:
                col = "gray"

            # validate: exactly one row per session (also re-checked per pair 
            # below)
            lp_sess_ns = np.sort(lp_df["sess_ns"].unique())
            for sess_n in lp_sess_ns:
                rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(rows) != 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")    

            sig_p_vals, sig_strs, sig_xs = [], [], []
            lp_sig_str = f"{line_plane_name:6} (between sessions):"
            # every ordered session pair (sess_n1 < sess_n2)
            for i, sess_n1 in enumerate(lp_sess_ns):
                row_1s = lp_df.loc[lp_df["sess_ns"] == sess_n1]
                for sess_n2 in lp_sess_ns[i + 1: ]:
                    row_2s = lp_df.loc[lp_df["sess_ns"] == sess_n2]
                    if len(row_1s) != 1 or len(row_2s) != 1:
                        raise RuntimeError(
                            "Expected exactly one row per session."
                            )
                    row1 = row_1s.loc[row_1s.index[0]]
                    row2 = row_2s.loc[row_2s.index[0]]

                    # errorbar tops (mean + last error value), NaN-tolerant
                    row1_highest = row1[data_col][0] + row1[data_col][-1]
                    row2_highest = row2[data_col][0] + row2[data_col][-1]      
                    highest[l] = np.nanmax(
                        [highest[l], row1_highest, row2_highest]
                        )
                    
                    if dry_run:
                        continue
                    
                    # p-value for the pair is stored on the first session's row
                    p_val = row1[
                        f"{prefix}p_vals_{int(sess_n1)}v{int(sess_n2)}"
                        ]
                    side = np.sign(row2[data_col][0] - row1[data_col][0])

                    sig_str = misc_analys.get_sig_symbol(
                        p_val, sensitivity=sensitivity, side=side, 
                        tails=permpar["tails"], p_thresh=permpar["p_val"], 
                        ctrl=ctrl
                        )

                    if len(sig_str):
                        sig_p_vals.append(p_val)
                        sig_strs.append(sig_str)
                        sig_xs.append([sess_n1, sess_n2])
                    
                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n1}v{sess_n2}: "
                        f"{p_val:.5f}{sig_str:3}"
                        )

            if dry_run:
                continue

            n = len(sig_p_vals)
            # vertical spacing between stacked bars: 1/8 of the current y range
            # NOTE(review): on pass 1, ylims may already have been raised on 
            # pass 0, so prop is recomputed from the larger range and can 
            # differ from the spacing assumed when the limits were raised.
            ylims = sub_ax.get_ylim()
            prop = np.diff(ylims)[0] / 8.0

            if pass_n == 0:
                logger.info(lp_sig_str, extra={"spacing": TAB})
                if n == 0:
                    continue

                # count number of significant comparisons, and adjust y limits
                ylims = [
                    ylims[0], np.nanmax(
                        [ylims[1], highest[l] + prop * (n + 1)])
                        ]
                sub_ax.set_ylim(ylims)
            else:
                if n == 0:
                    continue

                if ctrl:
                    mark_rel_y = 0.22
                    fontsize = 14
                else:
                    mark_rel_y = 0.18
                    fontsize = 20

                # add significance markers sequentially, on second pass
                y_pos = highest[l]
                for s, (p_val, sig_str, sig_x) in enumerate(
                    zip(sig_p_vals, sig_strs, sig_xs)
                ):  
                    y_pos = highest[l] + (s + 1) * prop
                    plot_util.plot_barplot_signif(
                        sub_ax, sig_x, y_pos, rel_y=0.11, color=col, lw=3, 
                        mark_rel_y=mark_rel_y, mark=sig_str, fontsize=fontsize
                        )
                highest[l] = np.nanmax([highest[l], y_pos])

                # NOTE(review): scaling by 1.1 assumes y_pos > 0; for a 
                # negative y_pos this would set the upper limit below y_pos.
                if y_pos > ylims[1]:
                    sub_ax.set_ylim([ylims[0], y_pos * 1.1])

    return highest
예제 #4
0
def add_scatterplot_markers(sub_ax,
                            permpar,
                            regr_corr,
                            rand_corr_med,
                            slope,
                            intercept,
                            p_val,
                            col="k",
                            diffs=False):
    """
    add_scatterplot_markers(sub_ax, permpar, regr_corr, rand_corr_med, slope, 
                            intercept)
    
    Adds markers to a scatterplot plot (reference line, quadrant shading, 
    regression line, slope and significance text).

    Required args:
        - sub_ax (plt subplot): 
            subplot
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - regr_corr (float):
            regression correlation
        - rand_corr_med (float):
            median of the random correlations            
        - slope (float):
            regression slope
        - intercept (float):
            regression intercept            
        - p_val (float):
            correlation p-value for significance marker
    
    Optional args:
        - col (str):
            regression line colour
            default: "k"
        - diffs (bool):
            if True, a horizontal line at 0 is plotted instead of the 
            identity line, and x and y limits are not matched
            default: False

    Returns:
        - sig_str (str):
            significance symbol string from misc_analys.get_sig_symbol(), 
            also prepended to the slope text
    """

    # match x and y limits
    if not diffs:
        lims = (np.min([sub_ax.get_xlim()[0],
                        sub_ax.get_ylim()[0]]),
                np.max([sub_ax.get_xlim()[1],
                        sub_ax.get_ylim()[1]]))
        sub_ax.set_xlim(lims)
        sub_ax.set_ylim(lims)

    for axis in ["x", "y"]:
        plot_util.set_interm_ticks(np.asarray(sub_ax),
                                   n_ticks=4,
                                   axis=axis,
                                   share=False,
                                   fontweight="bold",
                                   update_ticks=True)

    # x and y limits are read separately: they only coincide when diffs is 
    # False (matched above), so y extents must not be taken from the x axis
    xlims = sub_ax.get_xlim()
    ylims = sub_ax.get_ylim()

    # plot lines
    line_kwargs = {
        "ls": plot_helper_fcts.VDASH,
        "color": "k",
        "alpha": 0.2,
        "zorder": -15,
        "lw": 4,
    }

    if diffs:
        sub_ax.axhline(y=0, **line_kwargs)
    else:
        # identity line
        sub_ax.plot([xlims[0], xlims[1]], [xlims[0], xlims[1]], **line_kwargs)

    # shade in opposite quadrants (x < 0 < y, and y < 0 < x)
    sub_ax.fill_between([xlims[0], 0],
                        0,
                        ylims[1],
                        alpha=0.075,
                        facecolor="k",
                        edgecolor="none",
                        zorder=-16)
    sub_ax.fill_between([0, xlims[1]],
                        ylims[0],
                        0,
                        alpha=0.075,
                        facecolor="k",
                        edgecolor="none",
                        zorder=-16)

    # regression line, spanning the x range
    x_line = xlims
    y_line = np.array(x_line) * slope + intercept
    sub_ax.plot(x_line,
                y_line,
                ls=plot_helper_fcts.VDASH,
                color=col,
                alpha=0.8,
                lw=line_kwargs["lw"])

    # get significance info and slope
    side = np.sign(regr_corr - rand_corr_med)
    sensitivity = misc_analys.get_sensitivity(permpar)
    sig_str = misc_analys.get_sig_symbol(p_val,
                                         sensitivity=sensitivity,
                                         side=side,
                                         tails=permpar["tails"],
                                         p_thresh=permpar["p_val"])

    # write slope on subplot, near the bottom right corner
    x_pos = xlims[0] + (xlims[1] - xlims[0]) * 0.97
    y_pos = ylims[0] + (ylims[1] - ylims[0]) * 0.05
    slope_str = f"{sig_str} slope: {slope:.2f}"
    sub_ax.text(x_pos,
                y_pos,
                slope_str,
                fontweight="bold",
                style="italic",
                fontsize=16,
                ha="right")

    return sig_str
예제 #5
0
def plot_idx_correlations(idx_corr_df,
                          permpar,
                          figpar,
                          permute="sess",
                          corr_type="corr",
                          title=None,
                          small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI 
                index correlations (low, median, high)
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title (if normalized correlations are plotted, 
            "Correlations" is replaced with "Normalized correlations")
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array): 
            array of subplots
    """

    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # title is optional: only adjust it if one was provided
            if title is not None:
                title = title.replace(
                    "Correlations", "Normalized correlations"
                    )

    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2)  # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(figpar,
                                            kind="reg",
                                            n_sub=int(n_pairs / 2))

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"

    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    comp_info = misc_analys.get_comp_info(permpar)
    # loop invariant: only depends on permpar
    sensitivity = misc_analys.get_sensitivity(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI stored as (low, median, high)
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(sub_ax,
                              extr,
                              med=CI[1],
                              x=li,
                              width=0.45,
                              med_rat=0.025)

            y = row[f"{col_base}{norm_str}_corrs"]

            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(sub_ax,
                               x=li,
                               y=y,
                               err=err,
                               color=col,
                               capsize=8)

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(p_val,
                                                 sensitivity=sensitivity,
                                                 side=side,
                                                 tails=permpar["tails"],
                                                 p_thresh=permpar["p_val"])

            if len(sig_str):
                # place the marker above both the CI and the errorbar
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(sub_ax,
                                          li,
                                          high,
                                          rel_y=0.1,
                                          color=col,
                                          fontsize=24,
                                          mark=sig_str)

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax,
                                         datatype="roi",
                                         lines=["", ""],
                                         xticks=[0, 1],
                                         ylab="",
                                         xlab="",
                                         kind="traces")

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            if row_n == 0:
                # padding columns (beyond the number of pairs) get no title
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    sub_ax.set_title(sess_pair_title,
                                     fontweight="bold",
                                     y=1.07)
                sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(ax[row_n],
                                   3,
                                   axis="y",
                                   weight="bold",
                                   share=False,
                                   update_ticks=True)

    return ax
예제 #6
0
def plot_ex_roi_hists(ex_idx_df, sesspar, permpar, figpar, title=None):
    """
    plot_ex_roi_hists(ex_idx_df, sesspar, permpar, figpar)

    Plot example ROI histograms: the random index distribution, with CI 
    edges, the example ROI's index (with its percentile and significance) 
    and a reference line at 0.

    Required args:
        - ex_idx_df (pd.DataFrame):
            dataframe with a row for the example ROI, and the following 
            columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
            - CI_perc (list): confidence interval percentile limits
            - roi_idxs (float): example ROI index value
            - roi_idx_percs (float): percentile of the example ROI index
        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str): 
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots
    """
    
    # work on a copy, and add dummy (all-zero) binned ROI indices so that 
    # only the random index distribution shows up in the histograms
    ex_idx_df = copy.deepcopy(ex_idx_df)
    ex_idx_df["roi_idx_binned"] = [
        np.zeros_like(rand_idx_binned) 
        for rand_idx_binned in ex_idx_df["rand_idx_binned"].tolist()
    ]
    
    with gen_util.TempWarningFilter("invalid value", RuntimeWarning):
        ax = plot_idxs(
            ex_idx_df, sesspar, figpar, plot="items", title=title, size="tall", 
            density=True, n_bins=40)

    # adjust x axes
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xticks([-0.5, 0, 0.5])
        sub_ax.set_xticklabels(["-0.5", "0", "0.5"])

    # loop invariant: only depends on permpar
    sensitivity = misc_analys.get_sensitivity(permpar)

    # add lines and labels
    for (line, plane), lp_df in ex_idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)

        if len(lp_df) == 0:
            continue
        elif len(lp_df) > 1:
            raise RuntimeError("Expected at most one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        sub_ax = ax[pl, li]
        xlims = sub_ax.get_xlim()

        # add CI markers: each CI edge is shaded out to the x limit with the 
        # same position (assumes CI_edges are ordered low, high)
        for c, (CI_val, CI_perc) in enumerate(
            zip(row["CI_edges"], row["CI_perc"])
            ):

            sub_ax.axvline(
                CI_val, ls=plot_helper_fcts.VDASH, c="red", lw=3.0, alpha=1.0, 
                label=f"p{CI_perc:0.2f}")
            sub_ax.axvspan(
                CI_val, xlims[c], color=plot_helper_fcts.DARKRED, alpha=0.1, 
                lw=0, zorder=-13
                )
        
        ex_perc = row["roi_idx_percs"]

        sig_str = misc_analys.get_sig_symbol(
            ex_perc, percentile=True, sensitivity=sensitivity, 
            p_thresh=permpar["p_val"]
            )

        # example ROI index
        sub_ax.axvline(
            x=row["roi_idxs"], ls=plot_helper_fcts.VDASH, c=col, lw=3.0, 
            alpha=0.8, label=f"p{ex_perc:0.2f}{sig_str}"
            )

        # reference line at 0
        sub_ax.axvline(
            x=0, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5
            )

        # reset the x limits (the spans and lines may have expanded them)
        sub_ax.set_xlim(xlims)

        sub_ax.legend()
    
    return ax
예제 #7
0
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar, 
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences.

    For each datatype (running velocity, then pupil diameter), plots one 
    violin per line/plane of the block differences, annotates each violin 
    with its p-value and significance symbol, and logs the p-values.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - run_block_diffs (list): 
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list): 
                pupil diameter differences per block
            - pupil_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int): 
            seed value to use. (-1 treated as None)
            default: None

    Returns:
        - ax (2D array): 
            array of subplots (1 x number of datatypes)

    Raises:
        - NotImplementedError: 
            if analyspar["scale"] is True, or if block_df contains more than 
            one session number
        - RuntimeError: 
            if block_df contains more than one row for a line/plane
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
            )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    # NaNs are only expected if bad values were not removed
    nanpol = None if analyspar["rem_bad"] else "omit"
    
    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    # subplot titles (with units), kept separate from the 'title' argument
    datatype_titles = {
        "run": "Running velocity (cm/s)",
        "pupil": "Pupil diameter (mm)",
        }
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False, 
        gridspec_kw={"wspace": 0.22}
        )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, datatype in enumerate(datatypes):
        datatype_sig_str = f"{datatype_strs[d]:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
                )
            lp_names[int(x)] = line_plane_name

            # skip empty groups, so that row_idx is never undefined or 
            # left over from a previous line/plane
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]
        
            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]
            
            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            # direction of the effect, used to evaluate one-tailed tests
            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side, 
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
                )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
            )
        
        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False) 

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False, 
            fontweight="bold", update_ticks=True
            )

        # p-value texts sit a fixed proportion above the final y limits 
        # (computed once, as a scalar, since adding text does not change them)
        ymin, ymax = sub_ax.get_ylim()
        y = ymax + (ymax - ymin) * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                # the very first text also carries the p-value label, and is 
                # right-aligned so the longer string extends leftwards
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, y, p_val_text, fontsize=20, weight="bold", ha=ha
                )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
            )
     
        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")    

        sub_ax.set_title(datatype_titles[datatype], weight="bold", y=1.2)
        
    return ax