Ejemplo n.º 1
0
def plot_ex_gabor_traces(ex_traces_df, stimpar, figpar, title=None):
    """
    plot_ex_gabor_traces(ex_traces_df, stimpar, figpar)

    Plots example Gabor traces, with one subplot per example ROI, arranged 
    in a near-square grid of subplots for each line/plane.

    Required args:
        - ex_traces_df (pd.DataFrame):
            dataframe with a row for each ROI, and the following columns, 
            in addition to the basic sess_df columns: 
            - time_values (list): values for each frame, in seconds
            - roi_ns (list): selected ROI number
            - traces_sm (list): selected ROI sequence traces, smoothed, with 
                dims: seq x frames
            - trace_stat (list): selected ROI trace mean or median
        - stimpar (dict):
            dictionary with keys of StimPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
    
    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - ValueError: if stimpar["stimtype"] is not "gabors", or if 
            ex_traces_df contains no rows
    """

    if stimpar["stimtype"] != "gabors":
        raise ValueError("Expected stimpar['stimtype'] to be 'gabors'.")

    if len(ex_traces_df) == 0:
        raise ValueError("Expected ex_traces_df to contain at least one row.")

    group_columns = ["lines", "planes"]

    # largest number of example ROIs in any line/plane, rounded up so that it 
    # fills a near-square grid of subplots
    n_per = np.max(
        [len(lp_df) for _, lp_df in ex_traces_df.groupby(group_columns)]
        )
    per_rows, per_cols = math_util.get_near_square_divisors(n_per)
    n_per = per_rows * per_cols

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=per_rows
        )
    figpar["init"]["subplot_hei"] = 1.36
    figpar["init"]["subplot_wid"] = 2.47
    figpar["init"]["ncols"] = per_cols * 2

    fig, ax = plot_util.init_fig(
        plot_helper_fcts.N_LINPLA * n_per, **figpar["init"]
        )
    if title is not None:
        fig.suptitle(title, y=1.03, weight="bold")

    # y limits returned for each subplot (left as NaN where no ROI is plotted)
    ylims = np.full(ax.shape + (2, ), np.nan)
    
    logger.info("Plotting individual traces...", extra={"spacing": TAB})
    # individual traces are drawn below this zorder, so they get rasterized
    raster_zorder = -12
    for (line, plane), lp_df in ex_traces_df.groupby(group_columns):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for i, idx in enumerate(lp_df.index):
            # fill each line/plane block of subplots column by column
            row_idx = int(pl * per_rows + i % per_rows)
            col_idx = int(li * per_cols + i // per_rows)
            sub_ax = ax[row_idx, col_idx]

            ylims[row_idx, col_idx] = plot_ex_gabor_roi_traces(
                sub_ax, 
                lp_df.loc[idx],
                col=col,
                dash=dash,
                zorder=raster_zorder - 1
            )

        # time values of the last ROI in the group (the x axis below assumes 
        # they are shared across ROIs - not checked here)
        time_values = np.asarray(lp_df.loc[lp_df.index[-1], "time_values"])
        
    sess_plot_util.format_linpla_subaxes(ax, fluor="dff", 
        area=False, datatype="roi", sess_ns=None, xticks=None, kind="traces", 
        modif_share=False)

    # fix x ticks and lims (the same for every subplot, so computed once)
    xlims = [time_values[0], time_values[-1]]
    xticks = np.linspace(*xlims, 6)
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xticks(xticks)
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold", skip=False)
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xlim(xlims)
    
    # reset y limits, skipping subplots without data
    for r in range(ax.shape[0]):
        for c in range(ax.shape[1]):
            if not np.isfinite(ylims[r, c].sum()):
                continue
            ax[r, c].set_ylim(ylims[r, c])

    plot_util.set_interm_ticks(
        ax, 2, axis="y", share=False, weight="bold", update_ticks=True
        )  

    # rasterize the gray lines
    logger.info("Rasterizing individual traces...", extra={"spacing": TAB})
    for sub_ax in ax.reshape(-1):
        sub_ax.set_rasterization_zorder(raster_zorder)

    return ax
Ejemplo n.º 2
0
def plot_stim_data_df(stim_data_df, stimpar, permpar, figpar, pop_stats=True, 
                      title=None):
    """
    plot_stim_data_df(stim_data_df, stimpar, permpar, figpar)

    Plots stimulus comparison data, as one bar subplot per line/plane, plus 
    an additional subplot for all line/planes together.

    Required args:
        - stim_data_df (pd.DataFrame):
            dataframe with one row per line/plane and one for all line/planes 
            together, and the basic sess_df columns, in addition to, 
            for each stimtype:
            - stimtype (list): absolute fractional change statistics (me, err)
            - raw_p_vals (float): uncorrected p-value for data differences 
                between stimulus types 
            - p_vals (float): p-value for data differences between stimulus 
                types, corrected for multiple comparisons and tails
        - stimpar (dict): 
            dictionary with keys of StimPar namedtuple 
            (stimpar["stimtype"] is expected to list two stimulus types)
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - pop_stats (bool):
            if True, analyses are run on population statistics, and not 
            individual tracked ROIs
            default: True
        - title (str):
            plot title
            default: None
    
    Returns:
        - ax (2D array): 
            array of subplots 
            (does not include added subplot for all line/plane data together)
    """

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    figpar["init"]["subplot_wid"] = 2.1
    figpar["init"]["subplot_hei"] = 4.2
    figpar["init"]["gs"] = {"hspace": 0.20, "wspace": 0.3}
    figpar["init"]["sharey"] = "row"
    
    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # extra subplot, placed to the right of the grid, for all line/planes
    sub_ax_all = fig.add_axes([1.05, 0.11, 0.3, 0.77])

    # shallow copy, so that the list stored in stimpar is never modified
    stimtypes = list(stimpar["stimtype"])

    # indicate bootstrapped error with wider capsize
    capsize = 8 if pop_stats else 6

    lp_data = []
    cols = []
    for (line, plane), lp_df in stim_data_df.groupby(["lines", "planes"]):
        x = [0, 1]
        # rows: statistic (mean first, then error values), cols: stimtype
        data = np.vstack(
            [lp_df[stimtypes[0]].tolist(), lp_df[stimtypes[1]].tolist()]
            ).T
        y = data[0]
        err = data[1:]

        if line != "all" and plane != "all":
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            alpha = 0.5
            sub_ax = ax[pl, li]
            lp_data.append(y)
            cols.append(col)
        else:
            col = plot_helper_fcts.NEARBLACK
            dash = None
            alpha = 0.2
            sub_ax = sub_ax_all
            sub_ax.set_title("all", fontweight="bold")
    
        plot_util.plot_bars(
            sub_ax, x, y=y, err=err, width=0.5, lw=None, alpha=alpha, 
            color=col, ls=dash, capsize=capsize
            )

    # add dots to the all subplot
    x_vals = np.asarray([-0.17, 0.25, -0.25, 0.17]) # to spread dots out
    lw = 4
    ms = 200
    for s, _ in enumerate(stimtypes):
        lp_stim_data = [data[s] for data in lp_data]
        sorter = np.argsort(lp_stim_data)
        for i, idx in enumerate(sorter):
            x_val = s + x_vals[i]
            # white behind
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, linewidth=lw, alpha=0.8, 
                color="white", zorder=10
                )
            
            # colored dots
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, alpha=0.6, linewidth=0, 
                color=cols[idx], zorder=11
                )
            
            # dot borders
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, color="None", 
                edgecolor=cols[idx], linewidth=lw, alpha=1, zorder=12
                )

    # add between stim significance 
    add_between_stim_sig(ax, sub_ax_all, stim_data_df, permpar)

    # add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", lines=None, 
        planes=["", ""], xticks=[0, 1], ylab="Absolute fractional change", 
        kind="reg", xlab=""
        )
    
    # adjust plot details, using more readable stimulus type names
    stimtype_names = list(stimtypes)
    if "visflow" in stimtypes:
        stimtype_names[stimtypes.index("visflow")] = "visual\nflow"
    for sub_ax in fig.axes:
        y_max = sub_ax.get_ylim()[1]
        sub_ax.set_ylim([0, y_max])
        sub_ax.set_xticks([0, 1])
        sub_ax.set_xticklabels(
            stimtype_names, weight="bold", rotation=45, ha="right"
            )
        sub_ax.tick_params(axis="x", bottom=False)
    sub_ax_all.set_xlim(ax[0, 0].get_xlim())
        
    plot_util.set_interm_ticks(
        np.asarray(sub_ax_all), 4, axis="y", share=False, weight="bold"
        )

    return ax
Ejemplo n.º 3
0
def plot_corr_ex_data_scatterplot(sub_ax,
                                  idx_corr_norm_row,
                                  corr_name="1v2",
                                  col="k"):
    """
    plot_corr_ex_data_scatterplot(sub_ax, idx_corr_norm_row)

    Plots example random correlation data in a scatterplot showing real versus
    example randomly generated data.

    Required args:
        - sub_ax (plt subplot): 
            subplot
        - idx_corr_norm_row (pd.Series):
            dataframe row with the following columns, in addition to the 
            basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_rand_ex_corrs (float): unnormalized intersession 
                ROI index correlations for an example of randomized data

            - {}v{}_corr_data (list): intersession values to correlate, 
                as (x values, y values)
            - {}v{}_rand_ex (list): intersession values for an example of 
                randomized data, as (x values, y values)
    
    Optional args:
        - corr_name (str):
            session pair correlation name, used in series columns
            default: "1v2"
        - col (str):
            color for real data
            default: "k"
    """

    sess_pair = corr_name.split("v")

    # example randomized data, in gray, underneath the real data
    x_perm, y_perm = np.asarray(idx_corr_norm_row[f"{corr_name}_rand_ex"])
    raw_rand_corr = idx_corr_norm_row[f"{corr_name}_rand_ex_corrs"]
    sub_ax.scatter(x_perm,
                   y_perm,
                   color="gray",
                   alpha=0.4,
                   marker="d",
                   lw=2,
                   label=f"Raw random corr: {raw_rand_corr:.2f}")

    # real data
    x_data, y_data = np.asarray(idx_corr_norm_row[f"{corr_name}_corr_data"])
    raw_corr = idx_corr_norm_row[f"{corr_name}_corrs"]
    sub_ax.scatter(x_data,
                   y_data,
                   color=col,
                   alpha=0.4,
                   lw=2,
                   label=f"Raw corr: {raw_corr:.2f}")

    sub_ax.set_ylabel(
        f"USI diff. between\nsession {sess_pair[0]} and {sess_pair[1]}",
        fontweight="bold")
    sub_ax.set_xlabel(f"Session {sess_pair[0]} USIs", fontweight="bold")
    sub_ax.legend()

    plot_util.set_interm_ticks(np.asarray(sub_ax),
                               n_ticks=4,
                               axis="x",
                               share=False,
                               fontweight="bold")
Ejemplo n.º 4
0
def plot_sess_traces(data_df, analyspar, sesspar, figpar, 
                     trace_col="trace_stats", row_col="sess_ns", 
                     row_order=None, split="by_exp", title=None, size="reg"):
    """
    plot_sess_traces(data_df, analyspar, sesspar, figpar) 
    
    Plots traces from dataframe, with one subplot row per row_col value, 
    within each line/plane.

    Required args:
        - data_df (pd.DataFrame):
            traces data frame with, in addition to the basic sess_df columns, 
            columns specified by trace_col, row_col, and a "time_values" column
        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - trace_col (str):
            dataframe column containing trace statistics, as 
            split x ROIs x frames x stats 
            default: "trace_stats"
        - row_col (str):
            dataframe column specifying the variable that defines rows 
            within each line/plane
            default: "sess_ns"
        - row_order (list):
            ordered list specifying the order in which to plot from row_col.
            If None, automatic sorting order is used.
            default: None 
        - split (str):
            data split, e.g. "by_exp", "exp_lock", "unexp_lock", "stim_onset" 
            or "stim_offset". If not "by_exp", x limits are made symmetric 
            around 0.
            default: "by_exp"
        - title (str):
            plot title
            default: None
        - size (str):
            subplot sizes ("small", "wide" or "reg")
            default: "reg"

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if a row_order value appears more than once within a 
            line/plane
        - ValueError: if no data is plotted
    """
    
    # retrieve session numbers, and infer row_order, if necessary
    sess_ns = None
    if row_col == "sess_ns":
        sess_ns = row_order
        if row_order is None:
            row_order = misc_analys.get_sess_ns(sesspar, data_df)

    elif row_order is None:
        row_order = data_df[row_col].unique()

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=len(row_order), sharey=False
        )

    if size == "small":
        figpar["init"]["subplot_hei"] = 1.51
        figpar["init"]["subplot_wid"] = 3.7
    elif size == "wide":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 4.8
        figpar["init"]["gs"] = {"wspace": 0.3, "hspace": 0.5}
    elif size == "reg":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 3.4
    else:
        gen_util.accepted_values_error("size", size, ["small", "wide", "reg"])

    fig, ax = plot_util.init_fig(len(row_order) * 4, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=1.0, weight="bold")

    # time values of the last row plotted, used to set the x limits
    time_values = None
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)

        # expected data color depends only on the line
        if line == "L2/3-Cux2":
            exp_col = "darkgray" # oddly, lighter than gray
        else:
            exp_col = "gray"

        for r, row_val in enumerate(row_order):
            rows = lp_df.loc[lp_df[row_col] == row_val]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected row_order instances to be unique per line/plane."
                    )
            row = rows.loc[rows.index[0]]

            sub_ax = ax[r + pl * len(row_order), li]

            plot_traces(
                sub_ax, row["time_values"], row[trace_col], split=split, 
                col=col, ls=dash, exp_col=exp_col, lab=False
                )
            time_values = row["time_values"]

    if time_values is None:
        raise ValueError("No data found to plot in data_df.")
            
    for sub_ax in ax.reshape(-1):
        plot_util.set_minimal_ticks(sub_ax, axis="y")

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=False, datatype="roi", sess_ns=sess_ns, xticks=None, 
        kind="traces", modif_share=False)

    # fix x ticks and lims, on every subplot
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    xlims = [np.min(time_values), np.max(time_values)]
    if split != "by_exp":
        xlims = [-xlims[1], xlims[1]]
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xlim(xlims)

    return ax
Ejemplo n.º 5
0
def plot_idx_correlations(idx_corr_df,
                          permpar,
                          figpar,
                          permute="sess",
                          corr_type="corr",
                          title=None,
                          small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI 
                index correlations
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations, 
                corrected for multiple comparisons and tails

        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr". If permute is 
            "sess" or "all" and corr_type is "corr", normalized correlations 
            are plotted.
            default: "corr"
        - title (str):
            plot title (when normalized correlations are plotted, 
            "Correlations" in the title is replaced with 
            "Normalized correlations")
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array): 
            array of subplots
    """

    # normalized correlations are only used for "diff_corr" correlations
    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            if title is not None:
                title = title.replace("Correlations", "Normalized correlations")

    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2)  # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(figpar,
                                            kind="reg",
                                            n_sub=int(n_pairs / 2))

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"

    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # y limits (also used as y ticks) for each plane, i.e. each subplot row
    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    # depends only on permpar, so computed once for all subplots
    sensitivity = misc_analys.get_sensitivity(permpar)

    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI, stored as (low, median, high)
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(sub_ax,
                              extr,
                              med=CI[1],
                              x=li,
                              width=0.45,
                              med_rat=0.025)

            y = row[f"{col_base}{norm_str}_corrs"]

            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(sub_ax,
                               x=li,
                               y=y,
                               err=err,
                               color=col,
                               capsize=8)

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(p_val,
                                                 sensitivity=sensitivity,
                                                 side=side,
                                                 tails=permpar["tails"],
                                                 p_thresh=permpar["p_val"])

            if len(sig_str):
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(sub_ax,
                                          li,
                                          high,
                                          rel_y=0.1,
                                          color=col,
                                          fontsize=24,
                                          mark=sig_str)

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax,
                                         datatype="roi",
                                         lines=["", ""],
                                         xticks=[0, 1],
                                         ylab="",
                                         xlab="",
                                         kind="traces")

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n, ax_row in enumerate(ax):
        for col_n, sub_ax in enumerate(ax_row):
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            if row_n == 0:
                # session pair titles on the top row only (the last column 
                # may be empty, if there is an odd number of pairs)
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    sub_ax.set_title(sess_pair_title,
                                     fontweight="bold",
                                     y=1.07)
                sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(ax_row,
                                   3,
                                   axis="y",
                                   weight="bold",
                                   share=False,
                                   update_ticks=True)

    return ax
Ejemplo n.º 6
0
def add_scatterplot_markers(sub_ax,
                            permpar,
                            regr_corr,
                            rand_corr_med,
                            slope,
                            intercept,
                            p_val,
                            col="k",
                            diffs=False):
    """
    add_scatterplot_markers(sub_ax, permpar, regr_corr, rand_corr_med, slope, 
                            intercept, p_val)
    
    Adds markers to a scatterplot plot (reference line, quadrant shading, 
    regression line, slope and significance text).

    Required args:
        - sub_ax (plt subplot): 
            subplot
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - regr_corr (float):
            regression correlation
        - rand_corr_med (float):
            median of the random correlations            
        - slope (float):
            regression slope
        - intercept (float):
            regression intercept            
        - p_val (float):
            correlation p-value for significance marker
    
    Optional args:
        - col (str):
            regression line colour
            default: "k"
        - diffs (bool):
            if True, x and y limits are not matched, and a horizontal line at 
            0 is plotted instead of the identity line
            default: False

    Returns:
        - sig_str (str):
            significance symbol for p_val (may be empty), also written on the 
            subplot with the slope
    """

    # match x and y limits
    if not diffs:
        lims = (np.min([sub_ax.get_xlim()[0],
                        sub_ax.get_ylim()[0]]),
                np.max([sub_ax.get_xlim()[1],
                        sub_ax.get_ylim()[1]]))
        sub_ax.set_xlim(lims)
        sub_ax.set_ylim(lims)

    for axis in ["x", "y"]:
        plot_util.set_interm_ticks(np.asarray(sub_ax),
                                   n_ticks=4,
                                   axis=axis,
                                   share=False,
                                   fontweight="bold",
                                   update_ticks=True)

    # plot lines
    lims = sub_ax.get_xlim()
    line_kwargs = {
        "ls": plot_helper_fcts.VDASH,
        "color": "k",
        "alpha": 0.2,
        "zorder": -15,
        "lw": 4,
    }

    if diffs:
        sub_ax.axhline(y=0, **line_kwargs)
    else:
        # identity line
        sub_ax.plot([lims[0], lims[1]], [lims[0], lims[1]], **line_kwargs)

    # vertical extent: in diffs mode, the y limits are not matched to the 
    # x limits, so they must be read separately
    ylims = sub_ax.get_ylim() if diffs else lims

    # shade in opposite quadrants
    sub_ax.fill_between([lims[0], 0],
                        0,
                        ylims[1],
                        alpha=0.075,
                        facecolor="k",
                        edgecolor="none",
                        zorder=-16)
    sub_ax.fill_between([0, lims[1]],
                        ylims[0],
                        0,
                        alpha=0.075,
                        facecolor="k",
                        edgecolor="none",
                        zorder=-16)

    # regression line
    x_line = lims
    y_line = np.array(x_line) * slope + intercept
    sub_ax.plot(x_line,
                y_line,
                ls=plot_helper_fcts.VDASH,
                color=col,
                alpha=0.8,
                lw=line_kwargs["lw"])

    # get significance info and slope
    side = np.sign(regr_corr - rand_corr_med)
    sensitivity = misc_analys.get_sensitivity(permpar)
    sig_str = misc_analys.get_sig_symbol(p_val,
                                         sensitivity=sensitivity,
                                         side=side,
                                         tails=permpar["tails"],
                                         p_thresh=permpar["p_val"])

    # write slope on subplot, in the bottom right corner
    x_pos = lims[0] + (lims[1] - lims[0]) * 0.97
    y_pos = ylims[0] + (ylims[1] - ylims[0]) * 0.05
    slope_str = f"{sig_str} slope: {slope:.2f}"
    sub_ax.text(x_pos,
                y_pos,
                slope_str,
                fontweight="bold",
                style="italic",
                fontsize=16,
                ha="right")

    return sig_str
Ejemplo n.º 7
0
def plot_corr_ex_data_histogram(sub_ax,
                                idx_corr_norm_row,
                                corr_name="1v2",
                                col="k"):
    """
    plot_corr_ex_data_histogram(sub_ax, idx_corr_norm_row)

    Plots example random correlation data in a histogram showing how 
    normalized residual correlations are calculated.

    Required args:
        - sub_ax (plt subplot): 
            subplot
        - idx_corr_norm_row (pd.Series):
            dataframe row with the following columns, in addition to the 
            basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_norm_corrs (float): normalized intersession ROI index 
                correlations
            - {}v{}_rand_corr_meds (float): median of randomized correlations

            - {}v{}_rand_corrs_binned (list): binned random unnormalized 
                intersession ROI index correlations
            - {}v{}_rand_corrs_bin_edges (list): first and last bin edges
    
    Optional args:
        - corr_name (str):
            session pair correlation name, used in series columns
            default: "1v2"
        - col (str):
            color for real data
            default: "k"
    """

    med = idx_corr_norm_row[f"{corr_name}_rand_corr_meds"]
    raw_corr = idx_corr_norm_row[f"{corr_name}_corrs"]
    norm_corr = idx_corr_norm_row[f"{corr_name}_norm_corrs"]
    binned_corrs = \
        np.asarray(idx_corr_norm_row[f"{corr_name}_rand_corrs_binned"])
    # only the outer edges are stored: rebuild evenly spaced bin edges
    bin_edges = idx_corr_norm_row[f"{corr_name}_rand_corrs_bin_edges"]
    bin_edges = np.linspace(bin_edges[0], bin_edges[1], len(binned_corrs) + 1)

    # histogram of pre-binned counts, using the counts as weights
    sub_ax.hist(bin_edges[:-1],
                bin_edges,
                weights=binned_corrs,
                color="gray",
                alpha=0.45,
                density=True)
    # median line
    sub_ax.axvline(x=med, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5)

    # corr line
    sub_ax.axvline(x=raw_corr,
                   ls=plot_helper_fcts.VDASH,
                   c=col,
                   lw=3.0,
                   alpha=0.7)

    # adjust axes so that at least 1/5 of the graph is beyond the correlation 
    # value (leave_space is a distance, so it is applied relative to 
    # raw_corr), and so that the -1 or 1 edge is visible
    xlims = list(sub_ax.get_xlim())
    if raw_corr < med:
        leave_space = np.absolute(xlims[1] - raw_corr) / 3
        xlims[0] = np.min([xlims[0], -1.08, raw_corr - leave_space])
        edge = -1
    else:
        leave_space = np.absolute(raw_corr - xlims[0]) / 3
        xlims[1] = np.max([xlims[1], 1.08, raw_corr + leave_space])
        edge = 1

    # edge line
    sub_ax.axvline(x=edge, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5)

    # shift limits
    sub_ax.set_xlim(xlims)
    ylims = list(sub_ax.get_ylim())
    sub_ax.set_ylim(ylims[0], ylims[1] * 1.3)

    # set initial x ticks with more optimal interval: any tick at or beyond 
    # +/-1 is clamped to +/-1, and the other tick is placed a rounded number 
    # of steps away
    n_ticks = 3
    xticks = plot_util.rounded_lims(xlims)
    for i, bound in zip([0, 1], [-1, 1]):
        if np.absolute(xticks[i]) >= 1:
            xticks[i] = bound
            step = np.diff(xticks)[0] / (n_ticks + 1)
            o = math_util.get_order_of_mag(step)
            new_step = np.ceil(step / 10**o) * 10**o
            xticks[1 - i] = xticks[i] - new_step * (n_ticks + 1) * bound
    sub_ax.set_xticks(xticks)

    plot_util.set_interm_ticks(np.asarray(sub_ax),
                               n_ticks=n_ticks,
                               axis="x",
                               share=False,
                               fontweight="bold")

    sub_ax.set_ylabel("Density", fontweight="bold", labelpad=10)
    sub_ax.set_xlabel("Raw correlations", fontweight="bold")
    sub_ax.set_title(f"Normalized residual\ncorrelation: {norm_corr:.2f}",
                     fontweight="bold",
                     y=1.07)
Ejemplo n.º 8
0
def plot_rand_corr_ex_data(idx_corr_norm_df, title=None):
    """
    plot_rand_corr_ex_data(idx_corr_norm_df)

    Plots example random correlation data in a scatterplot and histogram to 
    show how normalized residual correlations are calculated.

    Required args:
        - idx_corr_norm_df (pd.DataFrame):
            dataframe with one row for a line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_norm_corrs (float): normalized intersession ROI index 
                correlations
            - {}v{}_rand_ex_corrs (float): unnormalized intersession 
                ROI index correlations for an example of randomized data
            - {}v{}_rand_corr_meds (float): median of randomized correlations

            - {}v{}_corr_data (list): intersession values to correlate
            - {}v{}_rand_ex (list): intersession values for an example of 
                randomized data
            - {}v{}_rand_corrs_binned (list): binned random unnormalized 
                intersession ROI index correlations
            - {}v{}_rand_corrs_bin_edges (list): bins edges

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (1D array): 
            array of subplots (scatterplot, histogram)

    Raises:
        - ValueError: if idx_corr_norm_df does not contain exactly one row
        - RuntimeError: if more or fewer than one session pair is found
    """

    # validate the input before creating the figure, so that no orphaned 
    # (open) figure is left behind if an error is raised
    if len(idx_corr_norm_df) != 1:
        raise ValueError("Expected idx_corr_norm_df to contain only one row.")

    sorted_pairs = get_sorted_sess_pairs(idx_corr_norm_df, norm=True)

    if len(sorted_pairs) != 1:
        raise RuntimeError(
            "Expected to find only one pair of sessions for which to plot data."
        )
    sess_pair = sorted_pairs[0]

    row = idx_corr_norm_df.loc[idx_corr_norm_df.index[0]]
    corr_name = f"{sess_pair[0]}v{sess_pair[1]}"
    _, _, col, _ = plot_helper_fcts.get_line_plane_idxs(
        row["lines"], row["planes"])

    plot_types = ["scatter", "hist"]
    fig, ax = plt.subplots(nrows=len(plot_types),
                           figsize=[8.7, 9.3],
                           gridspec_kw={"hspace": 0.7})

    if title is not None:
        fig.suptitle(title, y=0.95, weight="bold")

    # plot scatterplot
    scatt_ax = ax[0]
    plot_corr_ex_data_scatterplot(scatt_ax, row, corr_name=corr_name, col=col)

    # plot histogram
    hist_ax = ax[1]
    plot_corr_ex_data_histogram(hist_ax, row, corr_name=corr_name, col=col)

    plot_util.set_interm_ticks(ax,
                               n_ticks=4,
                               axis="y",
                               share=False,
                               fontweight="bold")

    return ax
Ejemplo n.º 9
0
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False, 
                       title=None):
    """
    plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar)

    Plots percentage of significant USIs, for the low and high tails, per 
    line/plane, and logs the associated p-values.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns:
            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation 
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig. ROIs 
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for percent 
                sig. ROIs
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for 
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig. 
                ROIs, corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - title (str):
            plot title
            default: None
        
    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if perc_sig_df contains more than one session 
            number
        - RuntimeError: if several rows are found for a line/plane, and 
            by_mouse is False
    """  

    perc_sig_df = perc_sig_df.copy(deep=True)

    nanpol = None if analyspar["rem_bad"] else "omit"

    sess_ns = perc_sig_df["sess_ns"].unique()
    if len(sess_ns) != 1:
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."
            )

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", n_sub=1, 
        sharex=True, sharey=True)

    figpar["init"]["sharey"] = True
    figpar["init"]["subplot_wid"] = 3.4
    figpar["init"]["gs"] = {"wspace": 0.18}
    if by_mouse:
        figpar["init"]["subplot_hei"] = 8.4
    else:
        figpar["init"]["subplot_hei"] = 3.5

    fig, ax = plot_util.init_fig(2, **figpar["init"])
    if title is not None:
        y = 0.98 if by_mouse else 1.07
        fig.suptitle(title, y=y, weight="bold")

    tail_order = ["Low tail", "High tail"]
    tail_keys = ["lo", "hi"]
    # chance level (in %) for a single tail
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    comp_info = misc_analys.get_comp_info(permpar)
    
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for t, (tail, key) in enumerate(zip(tail_order, tail_keys)):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")
        sub_ax.set_ylim(ylims)

        # replace bottom spine with line at 0
        sub_ax.spines['bottom'].set_visible(False)
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        CIs = np.full((n_linpla, 2), np.nan)
        CI_meds = np.full(n_linpla, np.nan)

        tail_sig_str = f"{tail:9}:"
        # tick labels are stored by x position (not in groupby order), so 
        # that each label lines up with its data, and the number of labels 
        # always matches the number of ticks, even if some line/plane 
        # combinations are missing
        linpla_names = [""] * n_linpla
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            linpla_names[int(x_index)] = linpla_name
            
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns") # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col, 
                    interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                    )
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                # zero-width CI, used only to draw the mean/median marker
                CI_dummy = np.repeat(mouse_data_mean, 2)
                plot_util.plot_CI(sub_ax, CI_dummy, med=mouse_data_mean, 
                    x=x_index, width=0.6, med_col=col, med_rat=0.01)
            else:
                # collect confidence interval data (null CIs are stored as 
                # [low, median, high])
                row = lp_df.loc[df_indices[0]]
                mouse_cols = [col]
                CIs[x_index] = np.asarray(row[f"{data_key}_null_CIs"])[
                    np.asarray([0, 2])
                    ]
                CI_meds[x_index] = row[f"{data_key}_null_CIs"][1]

            if by_mouse:
                perc_p_vals = []
                rel_y = 0.05
            else:
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                # plot UFOs
                err = None
                no_line = True
                if not by_mouse:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False
                # indicate bootstrapped error with wider capsize
                plot_util.plot_ufo(
                    sub_ax, x_index, perc_sig_df.loc[df_i, data_key], err,
                    color=mouse_col, capsize=8, no_line=no_line
                    )

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                perc = perc_sig_df.loc[df_i, data_key]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                side = np.sign(perc - chance)
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                    )                
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side, 
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if len(sig_str):
                    perc_high = perc + err if err is not None else perc
                    plot_util.add_signif_mark(sub_ax, x_index, perc_high, 
                        rel_y=rel_y, color=mouse_col, fontsize=24, 
                        mark=sig_str) 

                if by_mouse:
                    perc_p_vals.append(
                        (int(np.around(perc)), p_val, sig_str)
                    )
                else:
                    tail_sig_str = (
                        f"{tail_sig_str}{p_val:.5f}{sig_str:3}"
                        )

            if by_mouse: # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
                        f"{p_val:.5f}{sig_str:3}"
                        )
                
        # add chance information
        if by_mouse:
            sub_ax.axhline(
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5, 
                zorder=-12
                )
        else:
            plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, 
                x=np.arange(n_linpla), width=0.45, med_rat=0.025, zorder=-12)

        logger.info(tail_sig_str, extra={"spacing": TAB})
    
    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False) 
        plot_util.set_ticks(
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2)
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    ax[0, 0].set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if ax[0, 0].get_ylim()[0] < 0:
        ax[0, 0].set_ylim([ylims[0], ax[0, 0].get_ylim()[1]])

    return ax
Ejemplo n.º 10
0
def plot_idxs(idx_df, sesspar, figpar, plot="items", density=True, n_bins=40, 
              title=None, size="reg"):
    """
    plot_idxs(idx_df, sesspar, figpar)

    Plots histograms of ROI indices (or ROI index percentiles), against the 
    distribution of random indices, for each line/plane and session.

    Required args:
        - idx_df (pd.DataFrame):
            dataframe with indices for different line/plane combinations, and 
            the following columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
            - CI_perc (list): confidence interval percentile limits
            if plot == "items":
            - roi_idx_binned (list): bin counts for the ROI indices
            if plot == "percs":
            - perc_idx_binned (list): bin counts for the ROI index percentiles
            optionally:
            - n_signif_lo (int): number of significant ROIs (low) 
            - n_signif_hi (int): number of significant ROIs (high)

        - sesspar (dict): 
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - plot (str): 
            type of data to plot ("items" or "percs")
            default: "items"
        - density (bool): 
            if True, histograms are plotted as densities
            default: True
        - n_bins (int): 
            number of bins to use in histograms
            default: 40
        - title (str): 
            plot title
            default: None
        - size (str): 
            plot size ("reg", "small" or "tall")
            default: "reg"
        
    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if several rows are found for a line/plane/session
    """

    # validate "plot" once: the tick formatting at the end relies on it
    if plot == "items":
        data_key = "roi_idx_binned"
        CI_key = "CI_edges"
    elif plot == "percs":
        data_key = "perc_idx_binned"
        CI_key = "CI_perc"
    else:
        gen_util.accepted_values_error("plot", plot, ["items", "percs"])

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_df)

    n_plots = len(sess_ns) * 4
    figpar["init"]["sharey"] = "row"
    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", 
        n_sub=len(sess_ns), sharex=(plot == "percs"))

    y = 1
    if size == "reg":
        subplot_hei = 3.2
        subplot_wid = 5.5
    elif size == "small":
        y = 1.04
        subplot_hei = 2.40
        subplot_wid = 3.75
        figpar["init"]["gs"] = {"hspace": 0.25, "wspace": 0.30}
    elif size == "tall":
        y = 0.98
        subplot_hei = 5.3
        subplot_wid = 5.55
    else:
        gen_util.accepted_values_error("size", size, ["reg", "small", "tall"])
    
    figpar["init"]["subplot_hei"] = subplot_hei
    figpar["init"]["subplot_wid"] = subplot_wid
    # re-set, in case fig_init_linpla() modified it
    figpar["init"]["sharey"] = "row"
    
    fig, ax = plot_util.init_fig(n_plots, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=y, weight="bold")

    for (line, plane), lp_df in idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for s, sess_n in enumerate(sess_ns):
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected sess_ns to be unique per line/plane."
                    )
            row = rows.loc[rows.index[0]]

            sub_ax = ax[s + pl * len(sess_ns), li]

            # get percentage significant label
            perc_label = None
            if "n_signif_lo" in row.keys() and "n_signif_hi" in row.keys():
                n_sig_lo, n_sig_hi = row["n_signif_lo"], row["n_signif_hi"]
                nrois = np.sum(row["nrois"])
                perc_signif = np.sum([n_sig_lo, n_sig_hi]) / nrois * 100
                perc_label = (f"{perc_signif:.2f}% sig\n"
                    f"({n_sig_lo}-/{n_sig_hi}+ of {nrois})")                

            plot_stim_idx_hist(
                sub_ax, row[data_key], row[CI_key], n_bins=n_bins, 
                rand_data=row["rand_idx_binned"], 
                orig_edges=row["bin_edges"], 
                plot=plot, col=col, density=density, perc_label=perc_label)
            
            if size == "small":
                sub_ax.legend(fontsize="small")

    # Add plane, line info to plots
    y_lab = "Density" if density else "N ROIs"
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", ylab=y_lab, 
        xticks=None, sess_ns=None, kind="idx", modif_share=False, 
        xlab="Index", single_lab=True)

    # Add indices after setting formatting ("plot" was validated above, so 
    # it is either "percs" or "items" here)
    if plot == "percs":
        nticks = 5
        xticks = [int(np.around(x, 0)) for x in np.linspace(0, 100, nticks)]
        for sub_ax in ax[-1]:
            sub_ax.set_xticks(xticks)
            sub_ax.set_xticklabels(xticks, weight="bold")
    else:
        nticks = 3
        plot_util.set_interm_ticks(
            ax, nticks, axis="x", weight="bold", share=False, skip=False
            )
        
    return ax
Ejemplo n.º 11
0
def plot_roi_correlations(corr_df, figpar, title=None, log_scale=True):
    """
    plot_roi_correlations(corr_df, figpar)

    Plots ROI correlation histograms, one subplot per line/plane/session, 
    with a dashed vertical line at 0.

    Required args:
        - corr_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the 
            following columns, in addition to the basic sess_df columns:
            - bin_edges (list): first and last bin edge
            - corrs_binned (list): number of correlation values per bin
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - log_scale (bool):
            if True, a near logarithmic scale is used for the y axis (with a 
            linear range to reach 0, and break marks to mark the transition 
            from linear to log range), and x limits are fixed to [-1, 1]
            default: True

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - RuntimeError: if several rows are found for a line/plane/session
    """

    # every session number between the lowest and highest gets a column
    first_sess, last_sess = corr_df.sess_ns.min(), corr_df.sess_ns.max()
    sess_ns = np.arange(first_sess, last_sess + 1)
    n_sess = len(sess_ns)

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="prog", n_sub=n_sess
        )
    init_par = figpar["init"]
    init_par["subplot_hei"] = 3.0
    init_par["subplot_wid"] = 2.8
    init_par["sharex"] = log_scale
    if log_scale:
        init_par["sharey"] = True

    fig, ax = plot_util.init_fig(4 * n_sess, **init_par)
    if title is not None:
        fig.suptitle(title, y=1.02, weight="bold")

    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", ylab="Density", xlab="Correlation", 
        sess_ns=sess_ns, kind="prog", single_lab=True
        )

    log_base = 2
    for (line, plane), lp_df in corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for s, sess_n in enumerate(sess_ns):
            sess_rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            n_found = len(sess_rows)
            if n_found == 0:
                continue
            if n_found > 1:
                raise RuntimeError("Expected exactly one row.")
            sess_row = sess_rows.loc[sess_rows.index[0]]

            # rows: planes; columns: sessions, grouped by line
            sub_ax = ax[pl, li * n_sess + s]

            # rebuild the histogram from pre-binned counts: one sample per 
            # bin (at its left edge), weighted by the bin count
            counts = np.asarray(sess_row["corrs_binned"])
            edges = np.linspace(*sess_row["bin_edges"], len(counts) + 1)
            sub_ax.hist(
                edges[:-1], edges, weights=counts, color=col, alpha=0.6, 
                density=True
                )
            sub_ax.axvline(
                0, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5
                )

            sub_ax.spines["bottom"].set_visible(True)
            sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)

            if not log_scale:
                sub_ax.autoscale(axis="x", tight=True)
            else:
                sub_ax.set_yscale("log", base=log_base)
                sub_ax.set_xlim(-1, 1)

            sub_ax.autoscale(axis="y", tight=True)

    if log_scale:
        # near-logarithmic y axis, applied per group of session columns
        set_symlog_scale(ax, log_base=log_base, col_per_grp=n_sess, n_ticks=4)
    else:
        # linear y axis: share y ticks within each line's group of sessions
        n_grps = int(ax.shape[1] / n_sess)
        for ax_row in ax:
            for g in range(n_grps):
                grp_axes = ax_row[g * n_sess:(g + 1) * n_sess]
                plot_util.set_interm_ticks(
                    grp_axes, 4, axis="y", share=True, update_ticks=True
                    )

    # x ticks, for both scales
    plot_util.set_interm_ticks(
        ax, 4, axis="x", share=log_scale, update_ticks=True, 
        fontweight="bold"
        )

    return ax
Ejemplo n.º 12
0
def plot_pupil_run_trace_stats(trace_df, analyspar, figpar, split="by_exp", 
                               title=None):
    """
    plot_pupil_run_trace_stats(trace_df, analyspar, figpar)

    Plots pupil and running trace statistics, in two vertically stacked 
    subplots sharing the x axis.

    Required args:
        - trace_df (pd.DataFrame):
            dataframe with one row per session number, and the following 
            columns, in addition to the basic sess_df columns: 
            - run_trace_stats (list): 
                running velocity trace stats (split x frames x stats (me, err))
            - run_time_values (list):
                values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")
            - pupil_trace_stats (list): 
                pupil diameter trace stats (split x frames x stats (me, err))
            - pupil_time_values (list):
                values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")    

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - split (str):
            data split. Only "by_exp" is currently implemented.
            default: "by_exp"
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if split is not "by_exp", if the data is 
            scaled, or if trace_df does not contain exactly one row
    """

    # validate all inputs before creating the figure, so that no orphaned 
    # (open) figure is left behind if an error is raised
    if split != "by_exp":
        raise NotImplementedError("Only implemented split 'by_exp'.")

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
            )

    if len(trace_df) != 1:
        raise NotImplementedError(
            "Only implemented for a trace_df with one row."
            )
    row_idx = trace_df.index[0]

    datatypes = ["run", "pupil"]

    figpar["init"]["subplot_wid"] = 4.2
    figpar["init"]["subplot_hei"] = 2.2
    figpar["init"]["gs"] = {"hspace": 0.3}
    figpar["init"]["ncols"] = 1
    figpar["init"]["sharey"] = False
    figpar["init"]["sharex"] = True

    fig, ax = plot_util.init_fig(len(datatypes), **figpar["init"])

    if title is not None:
        fig.suptitle(title, weight="bold", y=1.0)

    exp_col = plot_util.LINCLAB_COLS["gray"]
    unexp_col = plot_util.LINCLAB_COLS["red"]
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[d, 0]

        time_values = trace_df.loc[row_idx, f"{datatype}_time_values"]
        trace_stats = trace_df.loc[row_idx, f"{datatype}_trace_stats"]

        seq_plots.plot_traces(
            sub_ax, time_values, trace_stats, split=split, col=unexp_col, 
            lab=False, exp_col=exp_col, hline=False
            )
    
        if datatype == "run":
            ylabel = "Running\nvelocity\n(cm/s)"
        elif datatype == "pupil":
            ylabel = "Pupil\ndiameter\n(mm)"
        sub_ax.set_ylabel(ylabel, weight="bold")

    # fix x ticks and lims: x is shared, so the limits are set on the last 
    # (bottom) subplot, using the last datatype's time values. Since split 
    # is always "by_exp" here, the time values are used as is.
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    xlims = [np.min(time_values), np.max(time_values)]
    sub_ax.set_xlim(xlims)
    sub_ax.set_xlabel("Time (s)", weight="bold")

    # expand y lims a bit and fix y ticks
    for sub_ax in ax.reshape(-1):
        plot_util.expand_lims(sub_ax, axis="y", prop=0.21)

    plot_util.set_interm_ticks(
        ax, 2, axis="y", share=False, weight="bold", update_ticks=True
        )

    return ax
Ejemplo n.º 13
0
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar, 
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences as violin plots per 
    line/plane, annotated with p-values, and logs the p-values.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns: 
            - run_block_diffs (list): 
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list): 
                for pupil diameter differences per block
            - pupil_raw_p_vals (list):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (list):
                p-value for differences within sessions, 
                corrected for multiple comparisons and tails

        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int): 
            seed value to use. (-1 treated as None)
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if the data is scaled, or if block_df contains 
            more than one session number
        - RuntimeError: if there is not exactly one row per line/plane
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
            )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    nanpol = None if analyspar["rem_bad"] else "omit"
    
    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False, 
        gridspec_kw={"wspace": 0.22}
        )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, datatype in enumerate(datatypes):
        datatype_sig_str = f"{datatype_strs[d]:16}:"
        sub_ax = ax[0, d]

        # tick labels stored by x position, to line up with the data
        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
                )
            lp_names[int(x)] = line_plane_name

            # check explicitly, so row_idx is never stale from a previous 
            # line/plane
            if len(lp_df) != 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]
        
            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]
            
            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side, 
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
                )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
            )
        
        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False) 

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False, 
            fontweight="bold", update_ticks=True
            )

        # place p-value texts slightly above the top of the y range (as a 
        # scalar, computed once: adding text does not change the limits)
        ylim_low, ylim_high = sub_ax.get_ylim()
        y = ylim_high + (ylim_high - ylim_low) * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                # first label also carries the p-value type, right-aligned
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, y, p_val_text, fontsize=20, weight="bold", ha=ha
                )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
            )
     
        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")    
        
        # separate name, so the "title" argument is not shadowed
        if datatype == "run":
            ax_title = "Running velocity (cm/s)"
        elif datatype == "pupil":
            ax_title = "Pupil diameter (mm)"

        sub_ax.set_title(ax_title, weight="bold", y=1.2)
        
    return ax