def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar,
                   between_sess_sig=True, data_col="diff_stats",
                   decoder_data=False, title=None, wide=False):
    """
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots error bar data across sessions, one subplot per line/plane, with
    the adjusted null CIs and within-session significance markers, and,
    optionally, between-session significance markers.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - {data_col} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data (low, median, high)
            - raw_p_vals (float): uncorrected p-value for data within
                sessions
            - p_vals (float): p-value for data within sessions,
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data
                differences between sessions
            - p_vals_{}v{} (float): p-value for data between sessions,
                corrected for multiple comparisons and tails
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data (shared y axes, accuracy
            y label, no area formatting)
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

    Returns:
        - ax (2D array):
            array of subplots
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)
    init_par = figpar["init"]
    init_par["sharey"] = True if decoder_data else "row"
    init_par["subplot_hei"] = 4.4
    init_par["gs"] = {"hspace": 0.2}
    init_par["subplot_wid"] = 3.0 if wide else 2.6
    init_par["gs"]["wspace"] = 0.3 # same spacing for both subplot widths

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_par)

    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    # first pass: error bars only; second pass: null CIs and sig. markers
    for add_sig in (False, True):
        if add_sig:
            logger.info(f"{comp_info}:", extra={"spacing": "\n"})

        for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
            )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
            )
            sub_ax = ax[pl, li]

            # retain, in order, only the sessions present for this line/plane
            idx_sess_pairs = []
            for sess_n in sess_ns:
                matching = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(matching) > 1:
                    raise RuntimeError(
                        "Expected 1 row per line/plane/session."
                    )
                if len(matching) == 1:
                    idx_sess_pairs.append((matching.index[0], sess_n))

            sess_indices = [idx for idx, _ in idx_sess_pairs]
            lp_sess_ns = [sess_n for _, sess_n in idx_sess_pairs]

            data = np.asarray([lp_df.loc[idx, data_col] for idx in sess_indices])

            if not add_sig:
                plot_util.plot_errorbars(
                    sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=col,
                    alpha=0.8, xticks="auto", line_dash=dash
                )
                continue

            # null CIs: column 1 is the median, columns 0 and 2 the limits
            null_CIs = np.asarray(
                [lp_df.loc[idx, "null_CIs"] for idx in sess_indices]
            )
            CI_meds = null_CIs[:, 1]
            CI_lims = null_CIs[:, np.asarray([0, 2])]

            plot_util.plot_CI(
                sub_ax, CI_lims.T, med=CI_meds, x=lp_sess_ns, width=0.45,
                color="lightgrey", med_col="gray", med_rat=0.03, zorder=-12
            )

            # significance markers within sessions, above the error bars
            y_maxes = data[:, 0] + data[:, -1]
            p_vals_corr = [lp_df.loc[idx, "p_vals"] for idx in sess_indices]

            lp_sig_str = f"{line_plane_name:6} (within session):"
            for sess_n, sess_data, CI_med, p_val, y_max in zip(
                lp_sess_ns, data, CI_meds, p_vals_corr, y_maxes
            ):
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity,
                    side=np.sign(sess_data[0] - CI_med),
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

                if len(sig_str):
                    plot_util.add_signif_mark(
                        sub_ax, sess_n, y_max, rel_y=0.15, color=col,
                        mark=sig_str
                    )

                lp_sig_str = (
                    f"{lp_sig_str}{TAB} S{sess_n}: "
                    f"{p_val:.5f}{sig_str:3}"
                )

            logger.info(lp_sig_str, extra={"spacing": TAB})

    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    if decoder_data:
        area = False
        ylab = (
            "Balanced accuracy (%)" if "balanced" in data_col
            else "Accuracy %"
        )
    else:
        area, ylab = True, None

    sess_plot_util.format_linpla_subaxes(
        ax, fluor=analyspar["fluor"], area=area, ylab=ylab, datatype="roi",
        sess_ns=sess_ns, kind="reg", xticks=sess_ns, modif_share=False
    )

    return ax
def add_between_stim_sig(ax, sub_ax_all, data_df, permpar):
    """
    add_between_stim_sig(ax, sub_ax_all, data_df, permpar)

    Plots significance markers for significant comparisons between the
    Gabor and visual flow stimulus types, and logs the p-values.

    Rows whose line and plane are both specific are drawn on the matching
    line/plane subplot; any row with "all" as its line or plane is drawn on
    sub_ax_all.

    Required args:
        - ax (plt Axis):
            axis
        - sub_ax_all (plt subplot):
            all line/plane data subplot
        - data_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following columns,
            in addition to the basic sess_df columns:
            - gabors (list): data stats (me, err) for Gabors
            - visflow (list): data stats (me, err) for visual flow
            - p_vals (float): p-value for the stimulus difference,
                corrected for multiple comparisons and tails
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    first_stim, second_stim = "gabors", "visflow"
    summary = "Gabors vs visual flow: "

    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        if len(lp_df) != 1:
            raise RuntimeError("Expected 1 row per line/plane/session.")
        row_idx = lp_df.index[0]

        bar_xs = [0, 1]
        # after transposing, row 0 holds the means, later rows the errors;
        # one column per stimulus type
        stats = np.vstack(
            [lp_df[first_stim].tolist(), lp_df[second_stim].tolist()]
        ).T
        means = stats[0]
        errs = stats[1:]
        highest = np.max(means + errs[-1])

        is_single_linpla = line != "all" and plane != "all"
        if not is_single_linpla:
            col = plot_helper_fcts.NEARBLACK
            linpla_name = "All"
            sub_ax = sub_ax_all
            # place the bar above the means of every row in the dataframe
            pooled_means = np.concatenate(
                [data_df[first_stim].tolist(), data_df[second_stim].tolist()],
                axis=0
            )[:, 0]
            highest = np.max([stats[0].max(), pooled_means.max()])
            mark_rel_y = 0.15
        else:
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(
                line, plane
            )
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]
            mark_rel_y = 0.18

        p_val = lp_df.loc[row_idx, "p_vals"]
        sig_str = misc_analys.get_sig_symbol(
            p_val, sensitivity=sensitivity, side=np.sign(means[1] - means[0]),
            tails=permpar["tails"], p_thresh=permpar["p_val"]
        )

        summary = f"{summary}{TAB}{linpla_name}: {p_val:.5f}{sig_str:3}"

        if not len(sig_str):
            continue

        plot_util.plot_barplot_signif(
            sub_ax, bar_xs, highest, rel_y=0.11, color=col, lw=3,
            mark_rel_y=mark_rel_y, mark=sig_str,
        )

    logger.info(summary, extra={"spacing": TAB})
def add_between_sess_sig(ax, data_df, permpar, data_col="diff_stats",
                         highest=None, ctrl=False, p_val_prefix=False,
                         dry_run=False):
    """
    add_between_sess_sig(ax, data_df, permpar)

    Plot significance markers for significant comparisons between sessions.

    Runs two passes over the line/plane groups. The first pass logs the
    between-session p-values and raises each subplot's upper y limit to
    leave room for the markers. The second pass draws one significance bar
    per significant session pair, stacked above the highest data point.

    Required args:
        - ax (plt Axis):
            axis
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - {data_col} (list): data stats (me, err)
            for session comparisons, e.g. 1v2:
            - p_vals_{}v{} (float): p-value for differences between sessions,
                corrected for multiple comparisons and tails
        - permpar (dict):
            dictionary with keys of PermPar namedtuple

    Optional args:
        - data_col (str):
            data column name in data_df
            default: "diff_stats"
        - highest (list):
            highest point for each line/plane, in order (same order as
            data_df.groupby(["lines", "planes"])). If provided, it is updated
            in place. If None, each value starts at 0.
            default: None
        - ctrl (bool):
            if True, significance symbols should use control colour and
            symbol
            default: False
        - p_val_prefix (bool):
            if True, p-value columns start with data_col as a prefix
            "{data_col}_p_vals_{}v{}".
            default: False
        - dry_run (bool):
            if True, a dry-run is done to get highest values, but nothing is
            plotted or logged.
            default: False

    Returns:
        - highest (list):
            highest point for each line/plane, in order, after plotting
            (includes the top significance bar, if any were drawn)

    Raises:
        - ValueError: if highest is provided with the wrong length
        - RuntimeError: if a line/plane has more than one row per session
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    # p-value columns may be named "{data_col}_p_vals_{}v{}"
    prefix = f"{data_col}_" if p_val_prefix else ""

    if not dry_run:
        logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    # pass 0: log p-values and extend y limits; pass 1: draw the markers
    for pass_n in [0, 1]: # add significance markers on the second pass
        linpla_grps = list(data_df.groupby(["lines", "planes"]))
        # on pass 1, highest is no longer None and carries pass 0 values
        if highest is None:
            highest = [0] * len(linpla_grps)
        elif len(highest) != len(linpla_grps):
            raise ValueError("If highest is provided, it must contain as "
                "many values as line/plane groups in data_df.")

        for l, ((line, plane), lp_df) in enumerate(linpla_grps):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            if ctrl:
                col = "gray"

            lp_sess_ns = np.sort(lp_df["sess_ns"].unique())
            # validate: exactly one row per session for this line/plane
            for sess_n in lp_sess_ns:
                rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
                if len(rows) != 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")

            sig_p_vals, sig_strs, sig_xs = [], [], []
            lp_sig_str = f"{line_plane_name:6} (between sessions):"
            # loop over each unique session pair, with sess_n1 < sess_n2
            for i, sess_n1 in enumerate(lp_sess_ns):
                row_1s = lp_df.loc[lp_df["sess_ns"] == sess_n1]
                for sess_n2 in lp_sess_ns[i + 1: ]:
                    row_2s = lp_df.loc[lp_df["sess_ns"] == sess_n2]
                    if len(row_1s) != 1 or len(row_2s) != 1:
                        raise RuntimeError(
                            "Expected exactly one row per session."
                            )
                    row1 = row_1s.loc[row_1s.index[0]]
                    row2 = row_2s.loc[row_2s.index[0]]

                    # top of each error bar: first stat + last stat
                    row1_highest = row1[data_col][0] + row1[data_col][-1]
                    row2_highest = row2[data_col][0] + row2[data_col][-1]
                    highest[l] = np.nanmax(
                        [highest[l], row1_highest, row2_highest]
                    )

                    if dry_run:
                        continue

                    # p-values are stored on the row of the earlier session
                    p_val = row1[
                        f"{prefix}p_vals_{int(sess_n1)}v{int(sess_n2)}"
                    ]
                    side = np.sign(row2[data_col][0] - row1[data_col][0])

                    sig_str = misc_analys.get_sig_symbol(
                        p_val, sensitivity=sensitivity, side=side,
                        tails=permpar["tails"], p_thresh=permpar["p_val"],
                        ctrl=ctrl
                    )

                    if len(sig_str):
                        sig_p_vals.append(p_val)
                        sig_strs.append(sig_str)
                        sig_xs.append([sess_n1, sess_n2])

                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n1}v{sess_n2}: "
                        f"{p_val:.5f}{sig_str:3}"
                    )

            if dry_run:
                continue

            n = len(sig_p_vals)
            ylims = sub_ax.get_ylim()
            # vertical spacing between stacked bars: 1/8 of current y range
            prop = np.diff(ylims)[0] / 8.0

            if pass_n == 0:
                logger.info(lp_sig_str, extra={"spacing": TAB})
                if n == 0:
                    continue

                # count number of significant comparisons, and adjust y limits
                ylims = [
                    ylims[0], np.nanmax(
                        [ylims[1], highest[l] + prop * (n + 1)])
                        ]
                sub_ax.set_ylim(ylims)
            else:
                if n == 0:
                    continue

                if ctrl:
                    mark_rel_y = 0.22
                    fontsize = 14
                else:
                    mark_rel_y = 0.18
                    fontsize = 20

                # add significance markers sequentially, on second pass
                y_pos = highest[l]
                for s, (p_val, sig_str, sig_x) in enumerate(
                    zip(sig_p_vals, sig_strs, sig_xs)
                ):
                    # each bar sits one spacing step above the previous one
                    y_pos = highest[l] + (s + 1) * prop
                    plot_util.plot_barplot_signif(
                        sub_ax, sig_x, y_pos, rel_y=0.11, color=col, lw=3,
                        mark_rel_y=mark_rel_y, mark=sig_str, fontsize=fontsize
                    )
                # record the top bar so later calls can stack above it
                highest[l] = np.nanmax([highest[l], y_pos])

                # grow the y limit if the top bar exceeds it
                if y_pos > ylims[1]:
                    sub_ax.set_ylim([ylims[0], y_pos * 1.1])

    return highest
def add_scatterplot_markers(sub_ax, permpar, regr_corr, rand_corr_med, slope,
                            intercept, p_val, col="k", diffs=False):
    """
    add_scatterplot_markers(sub_ax, permpar, regr_corr, rand_corr_med, slope,
                            intercept, p_val)

    Adds markers to a scatterplot (reference line, quadrant shading,
    regression line, slope and significance text).

    Required args:
        - sub_ax (plt subplot):
            subplot
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - regr_corr (float):
            regression correlation
        - rand_corr_med (float):
            median of the random correlations
        - slope (float):
            regression slope
        - intercept (float):
            regression intercept
        - p_val (float):
            correlation p-value for significance marker

    Optional args:
        - col (str):
            regression line colour
            default: "k"
        - diffs (bool):
            if True, a horizontal line at 0 is plotted instead of the
            identity line, and x and y limits are not matched
            default: False

    Returns:
        - sig_str (str):
            significance symbol (empty string if not significant)
    """

    # match x and y limits (only meaningful when x and y share units)
    if not diffs:
        lims = (
            np.min([sub_ax.get_xlim()[0], sub_ax.get_ylim()[0]]),
            np.max([sub_ax.get_xlim()[1], sub_ax.get_ylim()[1]])
        )
        sub_ax.set_xlim(lims)
        sub_ax.set_ylim(lims)

    for axis in ["x", "y"]:
        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=4, axis=axis, share=False,
            fontweight="bold", update_ticks=True
        )

    # plot lines
    lims = sub_ax.get_xlim()
    line_kwargs = {
        "ls": plot_helper_fcts.VDASH,
        "color": "k",
        "alpha": 0.2,
        "zorder": -15,
        "lw": 4,
    }
    if diffs:
        sub_ax.axhline(y=0, **line_kwargs)
    else:
        # identity line
        sub_ax.plot([lims[0], lims[1]], [lims[0], lims[1]], **line_kwargs)

        # shade in opposite quadrants (x and y limits are matched here)
        sub_ax.fill_between(
            [lims[0], 0], 0, lims[1], alpha=0.075, facecolor="k",
            edgecolor="none", zorder=-16
        )
        sub_ax.fill_between(
            [0, lims[1]], lims[0], 0, alpha=0.075, facecolor="k",
            edgecolor="none", zorder=-16
        )

    # regression line, spanning the x range
    x_line = lims
    y_line = np.array(x_line) * slope + intercept
    sub_ax.plot(
        x_line, y_line, ls=plot_helper_fcts.VDASH, color=col, alpha=0.8,
        lw=line_kwargs["lw"]
    )

    # get significance info and slope
    side = np.sign(regr_corr - rand_corr_med)
    sensitivity = misc_analys.get_sensitivity(permpar)
    sig_str = misc_analys.get_sig_symbol(
        p_val, sensitivity=sensitivity, side=side, tails=permpar["tails"],
        p_thresh=permpar["p_val"]
    )

    # write slope on subplot (bottom right corner). In diffs mode, x and y
    # limits differ, so the y position must be taken from the y limits.
    ylims = sub_ax.get_ylim() if diffs else lims
    x_pos = lims[0] + (lims[1] - lims[0]) * 0.97
    y_pos = ylims[0] + (ylims[1] - ylims[0]) * 0.05
    slope_str = f"{sig_str} slope: {slope:.2f}"
    sub_ax.text(
        x_pos, y_pos, slope_str, fontweight="bold", style="italic",
        fontsize=16, ha="right"
    )

    return sig_str
def plot_idx_correlations(idx_corr_df, permpar, figpar, permute="sess",
                          corr_type="corr", title=None, small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index
                correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession
                ROI index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI
                index correlations (low, median, high)
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations,
                corrected for multiple comparisons and tails

        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title. If normalized correlations are plotted,
            "Correlations" is replaced with "Normalized correlations".
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array):
            array of subplots
    """

    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # title is optional: only adjust it if one was provided
            if title is not None:
                title = title.replace(
                    "Correlations", "Normalized correlations"
                )

    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)

    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2) # multiple of 2
    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="reg", n_sub=int(n_pairs / 2)
    )

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"
    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    # loop-invariant significance settings
    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI: limits at indices 0 and 2, median at index 1
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(
                sub_ax, extr, med=CI[1], x=li, width=0.45, med_rat=0.025
            )

            y = row[f"{col_base}{norm_str}_corrs"]
            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(
                sub_ax, x=li, y=y, err=err, color=col, capsize=8
            )

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(
                p_val, sensitivity=sensitivity, side=side,
                tails=permpar["tails"], p_thresh=permpar["p_val"]
            )

            if len(sig_str):
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(
                    sub_ax, li, high, rel_y=0.1, color=col, fontsize=24,
                    mark=sig_str
                )

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", lines=["", ""], xticks=[0, 1], ylab="", xlab="",
        kind="traces"
    )

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            # session pair titles on the top row only
            if row_n == 0 and col_n < len(sess_pairs):
                s1, s2 = sess_pairs[col_n]
                sess_pair_title = f"Session {s1} v {s2}"
                sub_ax.set_title(
                    sess_pair_title, fontweight="bold", y=1.07
                )
            sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(
            ax[row_n], 3, axis="y", weight="bold", share=False,
            update_ticks=True
        )

    return ax
def plot_ex_roi_hists(ex_idx_df, sesspar, permpar, figpar, title=None):
    """
    plot_ex_roi_hists(ex_idx_df, sesspar, permpar, figpar)

    Plot example ROI histograms: the random index distribution for each
    line/plane, its CI limits (with the tails shaded), the example ROI index
    and a line at 0.

    Required args:
        - ex_idx_df (pd.DataFrame):
            dataframe with a row for the example ROI, and the following
            columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
                (lower, upper)
            - CI_perc (list): confidence interval percentile limits
            - roi_idxs (float): index value for the example ROI
            - roi_idx_percs (float): percentile of the example ROI index,
                used for its significance symbol
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots
    """

    # work on a copy, so the caller's dataframe is not modified
    ex_idx_df = copy.deepcopy(ex_idx_df)

    # add dummy binned_roi_idxs (all zeros), required by plot_idxs
    ex_idx_df["roi_idx_binned"] = [
        np.zeros_like(rand_idx_binned)
        for rand_idx_binned in ex_idx_df["rand_idx_binned"].tolist()
    ]

    with gen_util.TempWarningFilter("invalid value", RuntimeWarning):
        ax = plot_idxs(
            ex_idx_df, sesspar, figpar, plot="items", title=title,
            size="tall", density=True, n_bins=40
        )

    # adjust x axes
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xticks([-0.5, 0, 0.5])
        sub_ax.set_xticklabels(["-0.5", "0", "0.5"])

    # loop-invariant significance setting
    sensitivity = misc_analys.get_sensitivity(permpar)

    # add lines and labels
    for (line, plane), lp_df in ex_idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        if len(lp_df) == 0:
            continue
        elif len(lp_df) > 1:
            raise RuntimeError("Expected at most one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        sub_ax = ax[pl, li]
        xlims = sub_ax.get_xlim()

        # add CI markers: the c-th CI edge is shaded out to the c-th x limit
        # (assumes exactly two CI edges: lower, then upper)
        for c, (CI_val, CI_perc) in enumerate(
            zip(row["CI_edges"], row["CI_perc"])
        ):
            sub_ax.axvline(
                CI_val, ls=plot_helper_fcts.VDASH, c="red", lw=3.0, alpha=1.0,
                label=f"p{CI_perc:0.2f}"
            )
            sub_ax.axvspan(
                CI_val, xlims[c], color=plot_helper_fcts.DARKRED, alpha=0.1,
                lw=0, zorder=-13
            )

        ex_perc = row["roi_idx_percs"]
        sig_str = misc_analys.get_sig_symbol(
            ex_perc, percentile=True, sensitivity=sensitivity,
            p_thresh=permpar["p_val"]
        )

        sub_ax.axvline(
            x=row["roi_idxs"], ls=plot_helper_fcts.VDASH, c=col, lw=3.0,
            alpha=0.8, label=f"p{ex_perc:0.2f}{sig_str}"
        )

        sub_ax.axvline(
            x=0, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5
        )

        # reset the x limits (axvspan may have changed them)
        sub_ax.set_xlim(xlims)
        sub_ax.legend()

    return ax
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar,
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences, one subplot per datatype,
    with one violin per line/plane and its p-value written above it.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - run_block_diffs (list):
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list):
                for pupil diameter differences per block
            - pupil_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (float):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int):
            seed value to use. (-1 treated as None)
            default: None

    Returns:
        - ax (2D array):
            array of subplots

    Raises:
        - NotImplementedError: if the data is scaled, or block_df contains
          more than one session number
        - RuntimeError: if a line/plane does not have exactly one row
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
        )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    nanpol = None if analyspar["rem_bad"] else "omit"

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    subplot_titles = {
        "run": "Running velocity (cm/s)",
        "pupil": "Pupil diameter (mm)",
    }
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False,
        gridspec_kw={"wspace": 0.22}
    )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, datatype in enumerate(datatypes):
        datatype_sig_str = f"{datatype_strs[d]:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
            )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
            )
            lp_names[int(x)] = line_plane_name

            # require exactly one row, so row_idx is always defined and never
            # carried over from a previous line/plane
            if len(lp_df) != 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]

            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]

            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                )
            )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side,
                tails=permpar["tails"], p_thresh=permpar["p_val"]
            )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
            )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
        )

        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False)

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False,
            fontweight="bold", update_ticks=True
        )

        # p-values written just above the top of the subplot; np.diff
        # returns a 1-element array, so take its scalar value
        ylims = sub_ax.get_ylim()
        text_y = ylims[1] + np.diff(ylims)[0] * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                # first label also states whether p-values are corrected
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, text_y, p_val_text, fontsize=20, weight="bold", ha=ha
            )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles (without overwriting the figure title argument)
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
        )

        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")

        sub_ax.set_title(subplot_titles[datatype], weight="bold", y=1.2)

    return ax