def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar,
                   between_sess_sig=True, data_col="diff_stats",
                   decoder_data=False, title=None, wide=False):
    """
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots errorbar data across sessions, with one subplot per line/plane.
    Errorbars are drawn on a first pass over the line/plane groups. Null
    confidence intervals and within-session significance markers are added
    (and p-values logged) on a second pass.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - {data_col} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data
            - raw_p_vals (float): uncorrected p-value for data within
                sessions
            - p_vals (float): p-value for data within sessions,
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data
                differences between sessions
            - p_vals_{}v{} (float): p-value for data between sessions,
                corrected for multiple comparisons and tails
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

    Returns:
        - ax (2D array):
            array of subplots
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    def _get_sess_rows(lp_df):
        # collects the row label and session number for each requested
        # session that has data in this line/plane group
        row_labels, found_sess_ns = [], []
        for sess_n in sess_ns:
            matches = lp_df.loc[lp_df["sess_ns"] == sess_n]
            n_matches = len(matches)
            if n_matches > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            if n_matches == 1:
                row_labels.append(matches.index[0])
                found_sess_ns.append(sess_n)
        return row_labels, found_sess_ns

    figpar = sess_plot_util.fig_init_linpla(figpar)

    init_par = figpar["init"]
    init_par["sharey"] = True if decoder_data else "row"
    init_par["subplot_hei"] = 4.4
    # horizontal spacing is the same for narrow and wide subplots
    init_par["gs"] = {"hspace": 0.2, "wspace": 0.3}
    init_par["subplot_wid"] = 3.0 if wide else 2.6

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_par)
    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    # data is plotted on the first pass, and null CIs and significance
    # markers are added on the second pass
    for add_markers in (False, True):
        if add_markers:
            logger.info(f"{comp_info}:", extra={"spacing": "\n"})

        for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
                )
            sub_ax = ax[pl, li]

            row_labels, lp_sess_ns = _get_sess_rows(lp_df)
            data = np.asarray(
                [lp_df.loc[label, data_col] for label in row_labels]
                )

            if not add_markers:
                # plot errorbars
                plot_util.plot_errorbars(
                    sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=col,
                    alpha=0.8, xticks="auto", line_dash=dash
                    )
                continue

            # plot CIs (columns: low, median, high)
            null_CIs = np.asarray(
                [lp_df.loc[label, "null_CIs"] for label in row_labels]
                )
            CI_meds = null_CIs[:, 1]
            CI_bounds = null_CIs[:, np.asarray([0, 2])]

            plot_util.plot_CI(
                sub_ax, CI_bounds.T, med=CI_meds, x=lp_sess_ns, width=0.45,
                color="lightgrey", med_col="gray", med_rat=0.03, zorder=-12
                )

            # add significance markers within sessions
            y_maxes = data[:, 0] + data[:, -1]
            sides = [
                np.sign(sess_data[0] - CI_med)
                for sess_data, CI_med in zip(data, CI_meds)
                ]
            p_vals_corr = [lp_df.loc[label, "p_vals"] for label in row_labels]

            lp_sig_str = f"{line_plane_name:6} (within session):"
            for sess_n, p_val, side, y_max in zip(
                lp_sess_ns, p_vals_corr, sides, y_maxes
                ):
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side,
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if sig_str:
                    plot_util.add_signif_mark(
                        sub_ax, sess_n, y_max, rel_y=0.15, color=col,
                        mark=sig_str
                        )

                lp_sig_str = (
                    f"{lp_sig_str}{TAB} S{sess_n}: {p_val:.5f}{sig_str:3}"
                    )

            logger.info(lp_sig_str, extra={"spacing": TAB})

    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    if decoder_data:
        area = False
        if "balanced" in data_col:
            ylab = "Balanced accuracy (%)"
        else:
            ylab = "Accuracy %"
    else:
        area, ylab = True, None

    sess_plot_util.format_linpla_subaxes(
        ax, fluor=analyspar["fluor"], area=area, ylab=ylab, datatype="roi",
        sess_ns=sess_ns, kind="reg", xticks=sess_ns, modif_share=False
        )

    return ax
def add_between_stim_sig(ax, sub_ax_all, data_df, permpar):
    """
    add_between_stim_sig(ax, sub_ax_all, data_df, permpar)

    Plots significance markers for significant comparisons between stimulus
    types (Gabors vs visual flow), and logs the p-values.

    Required args:
        - ax (plt Axis):
            axis
        - sub_ax_all (plt subplot):
            all line/plane data subplot
        - data_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - gabors (list): Gabor data stats (me, err)
            - visflow (list): visual flow data stats (me, err)
            - p_vals (float): p-value for the difference between stimulus
                types, as read by this function (presumably corrected for
                multiple comparisons and tails - to confirm against caller)
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    gab_key, visflow_key = "gabors", "visflow"
    bar_xs = [0, 1]
    stim_sig_str = "Gabors vs visual flow: "

    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        if len(lp_df) != 1:
            raise RuntimeError("Expected 1 row per line/plane/session.")
        row_label = lp_df.index[0]

        # stats x stimulus type
        stim_stats = np.vstack(
            [lp_df[gab_key].tolist(), lp_df[visflow_key].tolist()]
            ).T
        means = stim_stats[0]
        errs = stim_stats[1:]
        highest = np.max(means + errs[-1])

        if line == "all" or plane == "all":
            # pooled data goes in the dedicated subplot, with the marker
            # placed relative to the highest mean across the full dataframe
            col = plot_helper_fcts.NEARBLACK
            linpla_name = "All"
            sub_ax = sub_ax_all
            all_stats = np.concatenate(
                [data_df[gab_key].tolist(), data_df[visflow_key].tolist()],
                axis=0
                )
            all_data_max = all_stats[:, 0].max()
            highest = np.max([stim_stats[0].max(), all_data_max])
            mark_rel_y = 0.15
        else:
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(
                line, plane
                )
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]
            mark_rel_y = 0.18

        p_val = lp_df.loc[row_label, "p_vals"]
        side = np.sign(means[1] - means[0])

        sig_str = misc_analys.get_sig_symbol(
            p_val, sensitivity=sensitivity, side=side,
            tails=permpar["tails"], p_thresh=permpar["p_val"]
            )
        stim_sig_str = (
            f"{stim_sig_str}{TAB}{linpla_name}: {p_val:.5f}{sig_str:3}"
            )

        if sig_str:
            plot_util.plot_barplot_signif(
                sub_ax, bar_xs, highest, rel_y=0.11, color=col, lw=3,
                mark_rel_y=mark_rel_y, mark=sig_str,
                )

    logger.info(stim_sig_str, extra={"spacing": TAB})
def add_between_sess_sig(ax, data_df, permpar, data_col="diff_stats",
                         highest=None, ctrl=False, p_val_prefix=False,
                         dry_run=False):
    """
    add_between_sess_sig(ax, data_df, permpar)

    Plots significance markers for significant comparisons between sessions.

    Runs two passes over the line/plane groups: on the first, p-values are
    logged and the y limits are expanded to make room for the markers, and
    on the second, the markers are stacked above the highest data point.

    Required args:
        - ax (plt Axis):
            axis
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - {data_col} (list): data stats (me, err)
            for session comparisons, e.g. 1v2:
            - p_vals_{}v{} (float): p-value for differences between sessions,
                corrected for multiple comparisons and tails
        - permpar (dict):
            dictionary with keys of PermPar namedtuple

    Optional args:
        - data_col (str):
            data column name in data_df
            default: "diff_stats"
        - highest (list):
            highest point for each line/plane, in order
            (if provided, the list is updated in place)
            default: None
        - ctrl (bool):
            if True, significance symbols should use control colour and
            symbol
            default: False
        - p_val_prefix (bool):
            if True, p-value columns start with data_col as a prefix
            "{data_col}_p_vals_{}v{}".
            default: False
        - dry_run (bool):
            if True, a dry-run is done to get highest values, but nothing is
            plotted or logged.
            default: False

    Returns:
        - highest (list):
            highest point for each line/plane, in order, after plotting
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    if p_val_prefix:
        prefix = f"{data_col}_"
    else:
        prefix = ""

    if not dry_run:
        logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    linpla_grps = list(data_df.groupby(["lines", "planes"]))
    n_grps = len(linpla_grps)

    if highest is None:
        highest = [0] * n_grps
    elif len(highest) != n_grps:
        raise ValueError("If highest is provided, it must contain as "
            "many values as line/plane groups in data_df.")

    # markers are sized and positioned with respect to the axis limits, so
    # limits are adjusted on the first pass, and markers added on the second
    for pass_n in (0, 1):
        for l, ((line, plane), lp_df) in enumerate(linpla_grps):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            if ctrl:
                col = "gray"

            def _sess_rows(sess_n, lp_df=lp_df):
                # rows of the current line/plane group for a session number
                return lp_df.loc[lp_df["sess_ns"] == sess_n]

            lp_sess_ns = np.sort(lp_df["sess_ns"].unique())
            for sess_n in lp_sess_ns:
                if len(_sess_rows(sess_n)) != 1:
                    raise RuntimeError(
                        "Expected 1 row per line/plane/session."
                        )

            sig_comps = [] # (p-value, symbol, [sess 1, sess 2])
            lp_sig_str = f"{line_plane_name:6} (between sessions):"
            for i, sess_n1 in enumerate(lp_sess_ns):
                row_1s = _sess_rows(sess_n1)
                for sess_n2 in lp_sess_ns[i + 1:]:
                    row_2s = _sess_rows(sess_n2)
                    if not (len(row_1s) == 1 and len(row_2s) == 1):
                        raise RuntimeError(
                            "Expected exactly one row per session."
                            )
                    row1 = row_1s.iloc[0]
                    row2 = row_2s.iloc[0]

                    # track the top of the highest errorbar in the group
                    tops = [
                        row[data_col][0] + row[data_col][-1]
                        for row in (row1, row2)
                    ]
                    highest[l] = np.nanmax([highest[l], *tops])

                    if dry_run:
                        continue

                    p_val_key = f"{prefix}p_vals_{int(sess_n1)}v{int(sess_n2)}"
                    p_val = row1[p_val_key]
                    side = np.sign(row2[data_col][0] - row1[data_col][0])

                    sig_str = misc_analys.get_sig_symbol(
                        p_val, sensitivity=sensitivity, side=side,
                        tails=permpar["tails"], p_thresh=permpar["p_val"],
                        ctrl=ctrl
                        )

                    if sig_str:
                        sig_comps.append((p_val, sig_str, [sess_n1, sess_n2]))

                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n1}v{sess_n2}: "
                        f"{p_val:.5f}{sig_str:3}"
                        )

            if dry_run:
                continue

            n_sig = len(sig_comps)
            ylims = sub_ax.get_ylim()
            prop = np.diff(ylims)[0] / 8.0

            if pass_n == 0:
                logger.info(lp_sig_str, extra={"spacing": TAB})

            if n_sig == 0:
                continue

            if pass_n == 0:
                # make room for the significant comparisons
                new_top = np.nanmax([ylims[1], highest[l] + prop * (n_sig + 1)])
                sub_ax.set_ylim([ylims[0], new_top])
                continue

            mark_rel_y, fontsize = (0.22, 14) if ctrl else (0.18, 20)

            # add significance markers sequentially, on second pass
            y_pos = highest[l]
            for level, (_, sig_str, sig_x) in enumerate(sig_comps, start=1):
                y_pos = highest[l] + level * prop
                plot_util.plot_barplot_signif(
                    sub_ax, sig_x, y_pos, rel_y=0.11, color=col, lw=3,
                    mark_rel_y=mark_rel_y, mark=sig_str, fontsize=fontsize
                    )
            highest[l] = np.nanmax([highest[l], y_pos])

            if y_pos > ylims[1]:
                sub_ax.set_ylim([ylims[0], y_pos * 1.1])

    return highest
def plot_idx_corr_scatterplots(idx_corr_df, permpar, figpar, permute="sess",
                               title=None):
    """
    plot_idx_corr_scatterplots(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlation scatterplots for individual session
    comparisons, over contours of the binned random correlation data.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following
            columns, in addition to the basic sess_df columns:

            for correlation data (normalized if corr_type is "diff_corr")
            for session comparisons (x, y), e.g. 1v2
            - binned_rand_stats (list): number of random correlation values
                per bin (xs x ys)
            - corr_data_xs (list): USI values for x
            - corr_data_ys (list): USI values for y
            - corrs (float): correlation between session data (x and y)
            - p_vals (float): p-value for correlation, corrected for
                multiple comparisons and tails
            - rand_corr_meds (float): median of the random correlations
            - raw_p_vals (float): p-value for intersession correlations
            - regr_coefs (float): regression correlation coefficient (slope)
            - regr_intercepts (float): regression correlation intercept
            - x_bin_mids (list): x mid point for each random correlation bin
            - y_bin_mids (list): y mid point for each random correlation bin

        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation performed ("tracking", "sess" or "all")
            default: "sess"
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots
    """

    # y axis shows session differences if sessions are permuted
    diffs = permute in ("sess", "all")

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    init_par = figpar["init"]
    init_par["sharex"] = False
    init_par["sharey"] = False
    init_par["subplot_hei"] = 4
    init_par["subplot_wid"] = 4
    init_par["gs"] = {"hspace": 0.4, "wspace": 0.4}

    fig, ax = plot_util.init_fig(4, **init_par)
    if title is not None:
        fig.suptitle(title, fontweight="bold", y=0.97)

    linpla_grps = idx_corr_df.groupby(["lines", "planes"])

    # first pass to plot
    sess_ns = None
    for (line, plane), lp_df in linpla_grps:
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(lp_df) != 1:
            raise RuntimeError("Expected exactly one row.")
        lp_row = lp_df.iloc[0]

        if sess_ns is None:
            # axis labels are set from the first line/plane's session pair
            sess_ns = lp_row["sess_ns"]
            xlabel = f"Session {sess_ns[0]} USIs"
            if diffs:
                ylabel = f"Session {sess_ns[1]} - {sess_ns[0]} USIs"
            else:
                ylabel = f"Session {sess_ns[1]} USIs"
        elif sess_ns != lp_row["sess_ns"]:
            raise RuntimeError("Expected all sess_ns to match.")

        # random correlation data density, drawn behind the scatterplot
        rand_stats = np.asarray(lp_row["binned_rand_stats"]).T
        sub_ax.contour(
            lp_row["x_bin_mids"], lp_row["y_bin_mids"], rand_stats,
            levels=6, cmap="Greys", zorder=-13, linewidths=4
            )

        # markers are made more transparent the more data points there are
        n_pts = len(lp_row["corr_data_xs"])
        alpha = 0.3**(n_pts / 300)
        sub_ax.scatter(
            lp_row["corr_data_xs"], lp_row["corr_data_ys"], color=col,
            alpha=alpha, lw=2, s=35
            )

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", xticks=None, ylab=ylabel, xlab=xlabel, kind="reg"
        )

    # second pass to add plot markings
    comp_info = misc_analys.get_comp_info(permpar)
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    sig_str = ""
    for (line, plane), lp_df in linpla_grps:
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
        sub_ax = ax[pl, li]

        # add markers back in (removed due to kind='reg')
        sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)
        sub_ax.spines["bottom"].set_visible(True)

        lp_row = lp_df.iloc[0]
        p_val_corr = lp_row["p_vals"]

        lp_sig_str = add_scatterplot_markers(
            sub_ax, permpar, lp_row["corrs"], lp_row["rand_corr_meds"],
            lp_row["regr_coefs"], lp_row["regr_intercepts"], p_val_corr,
            col=col, diffs=diffs
            )

        sig_str = (
            f"{sig_str}{TAB}{line_plane_name}: "
            f"{p_val_corr:.5f}{lp_sig_str:3}"
            )

    logger.info(sig_str, extra={"spacing": TAB})

    return ax
def plot_idx_correlations(idx_corr_df, permpar, figpar, permute="sess",
                          corr_type="corr", title=None, small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following
            columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index
                correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession
                ROI index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI
                index correlations
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations,
                corrected for multiple comparisons and tails

        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation performed ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title (if provided and correlations are normalized,
            "Correlations" is replaced with "Normalized correlations")
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array):
            array of subplots
    """

    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # title is optional: only adjust it if one was passed
            if title is not None:
                title = title.replace(
                    "Correlations", "Normalized correlations"
                    )
    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2) # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="reg", n_sub=int(n_pairs / 2)
        )

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"

    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    # depends only on permpar, so retrieved once for all comparisons
    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        # x tick label for the line, e.g. "L23-..." -> "L2/3"
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI is stored as (low, median, high)
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(
                sub_ax, extr, med=CI[1], x=li, width=0.45, med_rat=0.025
                )

            y = row[f"{col_base}{norm_str}_corrs"]
            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(
                sub_ax, x=li, y=y, err=err, color=col, capsize=8
                )

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(
                p_val, sensitivity=sensitivity, side=side,
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            if len(sig_str):
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(
                    sub_ax, li, high, rel_y=0.1, color=col, fontsize=24,
                    mark=sig_str
                    )

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", lines=["", ""], xticks=[0, 1], ylab="", xlab="",
        kind="traces"
        )

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            if row_n == 0:
                # padding columns (beyond the session pairs) get no title
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    sub_ax.set_title(
                        sess_pair_title, fontweight="bold", y=1.07
                        )
                sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(
            ax[row_n], 3, axis="y", weight="bold", share=False,
            update_ticks=True
            )

    return ax
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False,
                       title=None):
    """
    plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar)

    Plots percentage of significant USIs, with one subplot for the low tail
    and one for the high tail.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:

            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig.
                ROIs
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for
                percent sig. ROIs
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig.
                ROIs, corrected for multiple comparisons and tails

        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots
    """

    perc_sig_df = perc_sig_df.copy(deep=True)

    if analyspar["rem_bad"]:
        nanpol = None
    else:
        nanpol = "omit"

    if len(perc_sig_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."
            )

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="idx", n_sub=1, sharex=True, sharey=True
        )

    init_par = figpar["init"]
    init_par["sharey"] = True
    init_par["subplot_wid"] = 3.4
    init_par["gs"] = {"wspace": 0.18}
    init_par["subplot_hei"] = 8.4 if by_mouse else 3.5

    fig, ax = plot_util.init_fig(2, **init_par)
    if title is not None:
        fig.suptitle(title, y=(0.98 if by_mouse else 1.07), weight="bold")

    # chance level for one tail, as a percentage
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    tails = [("Low tail", "lo"), ("High tail", "hi")]
    for t, (tail, key) in enumerate(tails):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")
        sub_ax.set_ylim(ylims)

        # replace bottom spine with line at 0
        sub_ax.spines['bottom'].set_visible(False)
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        # null CI data, collected by x position (left as NaN if missing)
        CIs = np.full((n_linpla, 2), np.nan)
        CI_meds = np.full(n_linpla, np.nan)

        tail_sig_str = f"{tail:9}:"
        linpla_names = []
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            linpla_names.append(linpla_name)

            n_rows = len(lp_df)
            if n_rows == 0:
                continue
            if n_rows > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns") # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col,
                    interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                    )
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                # zero-width CI, so that only the mean/median line shows
                plot_util.plot_CI(
                    sub_ax, np.repeat(mouse_data_mean, 2),
                    med=mouse_data_mean, x=x_index, width=0.6, med_col=col,
                    med_rat=0.01
                    )
                perc_p_vals = []
                rel_y = 0.05
            else:
                # collect confidence interval data
                null_CI = lp_df.loc[df_indices[0]][f"{data_key}_null_CIs"]
                mouse_cols = [col]
                CIs[x_index] = np.asarray(null_CI)[np.asarray([0, 2])]
                CI_meds[x_index] = null_CI[1]
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                perc = perc_sig_df.loc[df_i, data_key]

                # plot UFOs
                if by_mouse:
                    err, no_line = None, True
                else:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False

                # indicate bootstrapped error with wider capsize
                plot_util.plot_ufo(
                    sub_ax, x_index, perc, err, color=mouse_col, capsize=8,
                    no_line=no_line
                    )

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                side = np.sign(perc - chance)
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                    )
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side,
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if sig_str:
                    if err is None:
                        perc_high = perc
                    else:
                        perc_high = perc + err
                    plot_util.add_signif_mark(
                        sub_ax, x_index, perc_high, rel_y=rel_y,
                        color=mouse_col, fontsize=24, mark=sig_str
                        )

                if by_mouse:
                    perc_p_vals.append(
                        (int(np.around(perc)), p_val, sig_str)
                    )
                else:
                    tail_sig_str = f"{tail_sig_str}{p_val:.5f}{sig_str:3}"

            if by_mouse: # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
                        f"{p_val:.5f}{sig_str:3}"
                        )

        # add chance information
        if by_mouse:
            sub_ax.axhline(
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0,
                alpha=0.5, zorder=-12
                )
        else:
            plot_util.plot_CI(
                sub_ax, CIs.T, med=CI_meds, x=np.arange(n_linpla),
                width=0.45, med_rat=0.025, zorder=-12
                )

        logger.info(tail_sig_str, extra={"spacing": TAB})

    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False)
        plot_util.set_ticks(
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2
            )
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    first_ax = ax[0, 0]
    first_ax.set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if first_ax.get_ylim()[0] < 0:
        first_ax.set_ylim([ylims[0], first_ax.get_ylim()[1]])

    return ax
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar,
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - run_block_diffs (list):
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list):
                for pupil diameter differences per block
            - pupil_raw_p_vals (list):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (list):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails

        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
            (not read by this function, as the figure is initialized
            directly)

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int):
            seed value to use. (-1 treated as None)
            default: None

    Returns:
        - ax (2D array):
            array of subplots

    Raises:
        - NotImplementedError: if analyspar["scale"] is set, or block_df
            contains more than one session number.
        - RuntimeError: if a line/plane group does not contain exactly one
            row.
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
            )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
            )

    nanpol = None if analyspar["rem_bad"] else "omit"

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    # subplot titles, kept separate from the figure title argument
    subplot_titles = {
        "run": "Running velocity (cm/s)",
        "pupil": "Pupil diameter (mm)",
    }
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False,
        gridspec_kw={"wspace": 0.22}
        )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, datatype in enumerate(datatypes):
        datatype_sig_str = f"{datatype_strs[d]:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
                )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
                )
            lp_names[int(x)] = line_plane_name

            # any row count other than 1 (including an empty group) is an
            # error, instead of falling through with no or a stale row index
            if len(lp_df) != 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]

            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]

            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side,
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
                )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
            )

        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False)

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False,
            fontweight="bold", update_ticks=True
            )

        # p-value text is placed just above the subplot, at a scalar
        # y position
        sub_ylims = sub_ax.get_ylim()
        text_y = sub_ylims[1] + np.diff(sub_ylims)[0] * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                # the first p-value text also carries the p-value type
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, text_y, p_val_text, fontsize=20, weight="bold", ha=ha
                )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
            )

        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")

        sub_ax.set_title(subplot_titles[datatype], weight="bold", y=1.2)

    return ax