def get_sess_grped_diffs_df(sessions, analyspar, stimpar, basepar, permpar,
                            split="by_exp", randst=None, parallel=False):
    """
    get_sess_grped_diffs_df(sessions, analyspar, stimpar, basepar, permpar)

    Returns split difference statistics for specific sessions, grouped across
    mice.

    Required args:
        - sessions (list):
            session objects
        - analyspar (AnalysPar):
            named tuple containing analysis parameters
        - stimpar (StimPar):
            named tuple containing stimulus parameters
        - basepar (BasePar):
            named tuple containing baseline parameters
        - permpar (PermPar):
            named tuple containing permutation parameters

    Optional args:
        - split (str):
            how to split data:
            "by_exp" (all exp, all unexp),
            "unexp_lock" (unexp, preceeding exp),
            "exp_lock" (exp, preceeding unexp),
            "stim_onset" (grayscr, stim on),
            "stim_offset" (stim off, grayscr)
            default: "by_exp"
        - randst (int or np.random.RandomState):
            random state or seed value to use. (-1 treated as None)
            default: None
        - parallel (bool):
            if True, some of the analysis is run in parallel across CPU cores
            default: False

    Returns:
        - diffs_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - diff_stats (list): split difference stats (me, err)
            - null_CIs (list): adjusted null CI for split differences
            - raw_p_vals (float): uncorrected p-value for differences within
                sessions
            - p_vals (float): p-value for differences within sessions,
                corrected for multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for differences
                between sessions
            - p_vals_{}v{} (float): p-value for differences between sessions,
                corrected for multiple comparisons and tails
    """

    nanpol = None if analyspar.rem_bad else "omit"

    if analyspar.tracked:
        misc_analys.check_sessions_complete(sessions, raise_err=True)

    sess_diffs_df = misc_analys.get_check_sess_df(sessions, None, analyspar)
    initial_columns = sess_diffs_df.columns.tolist()

    # retrieve ROI index information
    args_dict = {
        "analyspar": analyspar,
        "stimpar": stimpar,
        "basepar": basepar,
        "split": split,
        "return_data": True,
    }

    # sess x split x ROI
    split_stats, split_data = gen_util.parallel_wrap(
        get_sess_roi_split_stats, sessions, args_dict=args_dict,
        parallel=parallel, zip_output=True
    )

    # NOTE(review): presumably re-validates that the dataframe rows still
    # match the sessions before per-session results are attached - confirm
    misc_analys.get_check_sess_df(sessions, sess_diffs_df)
    sess_diffs_df["roi_split_stats"] = list(split_stats)
    sess_diffs_df["roi_split_data"] = list(split_data)

    columns = initial_columns + ["diff_stats", "null_CIs"]
    diffs_df = pd.DataFrame(columns=columns)

    group_columns = ["lines", "planes", "sess_ns"]
    aggreg_cols = [col for col in initial_columns if col not in group_columns]
    for lp_grp_vals, lp_grp_df in sess_diffs_df.groupby(["lines", "planes"]):
        lp_grp_df = lp_grp_df.sort_values(["sess_ns", "mouse_ns"])
        line, plane = lp_grp_vals
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)
        logger.info(
            f"Running permutation tests for {lp_name} sessions...",
            extra={"spacing": TAB}
        )

        # obtain ROI random split differences per session
        # done here to avoid OOM errors
        lp_rand_diffs = gen_util.parallel_wrap(
            get_rand_split_data, lp_grp_df["roi_split_data"].tolist(),
            args_list=[analyspar, permpar, randst], parallel=parallel,
            zip_output=False
        )

        # row labels of the sorted group, used to map each session's rows to
        # their position in lp_rand_diffs
        lp_grp_index = lp_grp_df.index.tolist()

        sess_diffs = []
        row_indices = []
        sess_ns = sorted(lp_grp_df["sess_ns"].unique())
        for sess_n in sess_ns:
            row_idx = len(diffs_df)
            row_indices.append(row_idx)
            sess_grp_df = lp_grp_df.loc[lp_grp_df["sess_ns"] == sess_n]

            grp_vals = list(lp_grp_vals) + [sess_n]
            for g, group_column in enumerate(group_columns):
                diffs_df.loc[row_idx, group_column] = grp_vals[g]

            # add aggregated values for initial columns
            diffs_df = misc_analys.aggreg_columns(
                sess_grp_df, diffs_df, aggreg_cols, row_idx=row_idx,
                in_place=True
            )

            # group ROI split stats across mice: split x ROIs
            grp_split_stats = np.concatenate(
                sess_grp_df["roi_split_stats"].to_numpy(), axis=-1
            )

            # take diff and stats across ROIs
            diffs = grp_split_stats[1] - grp_split_stats[0]
            diff_stats = math_util.get_stats(
                diffs, stats=analyspar.stats, error=analyspar.error,
                nanpol=nanpol
            )
            diffs_df.at[row_idx, "diff_stats"] = diff_stats.tolist()
            sess_diffs.append(diffs)

            # group random ROI split diffs across mice, and take stat
            rand_idxs = [lp_grp_index.index(idx) for idx in sess_grp_df.index]
            rand_diffs = math_util.mean_med(
                np.concatenate([lp_rand_diffs[r] for r in rand_idxs], axis=0),
                axis=0, stats=analyspar.stats, nanpol=nanpol
            )

            # get CIs and p-values
            p_val, null_CI = rand_util.get_p_val_from_rand(
                diff_stats[0], rand_diffs, return_CIs=True,
                p_thresh=permpar.p_val, tails=permpar.tails,
                multcomp=permpar.multcomp, nanpol=nanpol
            )
            diffs_df.loc[row_idx, "p_vals"] = p_val
            diffs_df.at[row_idx, "null_CIs"] = null_CI

        del lp_rand_diffs  # free up memory

        # calculate p-values between sessions (0-1, 0-2, 1-2...)
        p_vals = rand_util.comp_vals_acr_groups(
            sess_diffs, n_perms=permpar.n_perms, stats=analyspar.stats,
            paired=analyspar.tracked, nanpol=nanpol, randst=randst
        )
        p = 0
        for i, sess_n in enumerate(sess_ns):
            # j is the absolute position of the second session of the pair,
            # so that the p-value is recorded on both sessions' rows
            for j, sess_n2 in enumerate(sess_ns[i + 1:], start=i + 1):
                key = f"p_vals_{int(sess_n)}v{int(sess_n2)}"
                diffs_df.loc[row_indices[i], key] = p_vals[p]
                diffs_df.loc[row_indices[j], key] = p_vals[p]
                p += 1

    # add corrected p-values
    diffs_df = misc_analys.add_corr_p_vals(diffs_df, permpar)

    diffs_df["sess_ns"] = diffs_df["sess_ns"].astype(int)

    return diffs_df
def plot_sess_data(data_df, analyspar, sesspar, permpar, figpar,
                   between_sess_sig=True, data_col="diff_stats",
                   decoder_data=False, title=None, wide=False):
    """
    plot_sess_data(data_df, analyspar, sesspar, permpar, figpar)

    Plots errorbar data across sessions, one subplot per line/plane, with
    null confidence intervals and within-session significance markers.

    Required args:
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - {data_col} (list): data stats (me, err)
            - null_CIs (list): adjusted null CI for data
            - raw_p_vals (float): uncorrected p-value for data within
                sessions
            - p_vals (float): p-value for data within sessions, corrected for
                multiple comparisons and tails
            for session comparisons, e.g. 1v2:
            - raw_p_vals_{}v{} (float): uncorrected p-value for data
                differences between sessions
            - p_vals_{}v{} (float): p-value for data between sessions,
                corrected for multiple comparisons and tails
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - between_sess_sig (bool):
            if True, significance between sessions is logged and plotted
            default: True
        - data_col (str):
            dataframe column in which data to plot is stored
            default: "diff_stats"
        - decoder_data (bool):
            if True, data plotted is decoder data
            default: False
        - title (str):
            plot title
            default: None
        - wide (bool):
            if True, subplots are wider
            default: False

    Returns:
        - ax (2D array):
            array of subplots
    """

    sess_ns = misc_analys.get_sess_ns(sesspar, data_df)

    figpar = sess_plot_util.fig_init_linpla(figpar)

    init_par = figpar["init"]
    init_par["sharey"] = True if decoder_data else "row"
    init_par["subplot_hei"] = 4.4
    init_par["gs"] = {"hspace": 0.2, "wspace": 0.3}
    init_par["subplot_wid"] = 3.0 if wide else 2.6

    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_par)

    if title is not None:
        fig.suptitle(title, y=0.97, weight="bold")

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    # first pass draws the data, second pass adds CIs and significance marks
    for pass_n in [0, 1]:
        marking_pass = (pass_n == 1)
        if marking_pass:
            logger.info(f"{comp_info}:", extra={"spacing": "\n"})

        for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
            )
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            # gather the row label of each session present for this group
            row_labels, lp_sess_ns = [], []
            for sess_n in sess_ns:
                matches = lp_df.loc[lp_df["sess_ns"] == sess_n]
                n_matches = len(matches)
                if n_matches > 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")
                if n_matches == 0:
                    continue
                row_labels.append(matches.index[0])
                lp_sess_ns.append(sess_n)

            data = np.asarray(
                [lp_df.loc[label, data_col] for label in row_labels]
            )

            if not marking_pass:
                # plot errorbars
                plot_util.plot_errorbars(
                    sub_ax, data[:, 0], data[:, 1:].T, lp_sess_ns, color=col,
                    alpha=0.8, xticks="auto", line_dash=dash
                )
                continue

            # plot CIs
            null_CIs = np.asarray(
                [lp_df.loc[label, "null_CIs"] for label in row_labels]
            )
            CI_meds = null_CIs[:, 1]
            CI_bounds = null_CIs[:, np.asarray([0, 2])]

            plot_util.plot_CI(
                sub_ax, CI_bounds.T, med=CI_meds, x=lp_sess_ns, width=0.45,
                color="lightgrey", med_col="gray", med_rat=0.03, zorder=-12
            )

            # add significance markers within sessions
            y_maxes = data[:, 0] + data[:, -1]
            p_vals_corr = [lp_df.loc[label, "p_vals"] for label in row_labels]

            lp_sig_str = f"{line_plane_name:6} (within session):"
            for s, sess_n in enumerate(lp_sess_ns):
                # side of the null median on which the data statistic falls
                side = np.sign(data[s][0] - CI_meds[s])
                sig_str = misc_analys.get_sig_symbol(
                    p_vals_corr[s], sensitivity=sensitivity, side=side,
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

                if len(sig_str):
                    plot_util.add_signif_mark(
                        sub_ax, sess_n, y_maxes[s], rel_y=0.15, color=col,
                        mark=sig_str
                    )

                lp_sig_str = (
                    f"{lp_sig_str}{TAB} S{sess_n}: "
                    f"{p_vals_corr[s]:.5f}{sig_str:3}"
                )

            logger.info(lp_sig_str, extra={"spacing": TAB})

    if between_sess_sig:
        add_between_sess_sig(ax, data_df, permpar, data_col=data_col)

    if decoder_data:
        area = False
        if "balanced" in data_col:
            ylab = "Balanced accuracy (%)"
        else:
            ylab = "Accuracy %"
    else:
        area, ylab = True, None

    sess_plot_util.format_linpla_subaxes(
        ax, fluor=analyspar["fluor"], area=area, ylab=ylab, datatype="roi",
        sess_ns=sess_ns, kind="reg", xticks=sess_ns, modif_share=False
    )

    return ax
def add_between_stim_sig(ax, sub_ax_all, data_df, permpar):
    """
    add_between_stim_sig(ax, sub_ax_all, data_df, permpar)

    Plots significance markers for significant comparisons between stimulus
    types (Gabors vs visual flow), and logs the p-values.

    Required args:
        - ax (plt Axis):
            axis (indexed as ax[plane, line])
        - sub_ax_all (plt subplot):
            subplot for the data pooled across all lines/planes (used for the
            row where line or plane is "all")
        - data_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following columns,
            in addition to the basic sess_df columns:
            - gabors (list): Gabor data stats (me, err)
            - visflow (list): visual flow data stats (me, err)
            - p_vals (float): p-value for differences between stimulus
                types, corrected for multiple comparisons and tails
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    stimtypes = ["gabors", "visflow"]
    bar_xs = [0, 1]

    stim_sig_str = "Gabors vs visual flow: "
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        if len(lp_df) != 1:
            raise RuntimeError("Expected 1 row per line/plane/session.")
        row_idx = lp_df.index[0]

        # stats x stimulus type
        data = np.vstack([lp_df[stimtype].tolist() for stimtype in stimtypes]).T
        y = data[0]
        err = data[1:]
        highest = np.max(y + err[-1])

        is_all = (line == "all" or plane == "all")
        if is_all:
            col = plot_helper_fcts.NEARBLACK
            linpla_name = "All"
            sub_ax = sub_ax_all

            # in the pooled subplot, the bar is placed above the highest
            # statistic found across all rows
            all_data_max = np.concatenate(
                [data_df[stimtype].tolist() for stimtype in stimtypes], axis=0
            )[:, 0].max()
            highest = np.max([data[0].max(), all_data_max])
            mark_rel_y = 0.15
        else:
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]
            mark_rel_y = 0.18

        p_val = lp_df.loc[row_idx, "p_vals"]
        side = np.sign(y[1] - y[0])

        sig_str = misc_analys.get_sig_symbol(
            p_val, sensitivity=sensitivity, side=side,
            tails=permpar["tails"], p_thresh=permpar["p_val"]
        )
        stim_sig_str = \
            f"{stim_sig_str}{TAB}{linpla_name}: {p_val:.5f}{sig_str:3}"

        if not len(sig_str):
            continue

        plot_util.plot_barplot_signif(
            sub_ax, bar_xs, highest, rel_y=0.11, color=col, lw=3,
            mark_rel_y=mark_rel_y, mark=sig_str,
        )

    logger.info(stim_sig_str, extra={"spacing": TAB})
def plot_idx_corr_scatterplots(idx_corr_df, permpar, figpar, permute="sess",
                               title=None):
    """
    plot_idx_corr_scatterplots(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlation scatterplots for individual session
    comparisons, over a contour plot of the binned random correlation data.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following
            columns, in addition to the basic sess_df columns:

            for correlation data (normalized if corr_type is "diff_corr") for
            session comparisons (x, y), e.g. 1v2
            - binned_rand_stats (list): number of random correlation values
                per bin (xs x ys)
            - corr_data_xs (list): USI values for x
            - corr_data_ys (list): USI values for y
            - corrs (float): correlation between session data (x and y)
            - p_vals (float): p-value for correlation, corrected for
                multiple comparisons and tails
            - rand_corr_meds (float): median of the random correlations
            - raw_p_vals (float): p-value for intersession correlations
            - regr_coefs (float): regression correlation coefficient (slope)
            - regr_intercepts (float): regression correlation intercept
            - x_bin_mids (list): x mid point for each random correlation bin
            - y_bin_mids (list): y mid point for each random correlation bin
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots
    """

    # for "sess" and "all" permutations, the y axis shows USI differences
    diffs = permute in ["sess", "all"]

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    init_par = figpar["init"]
    init_par["sharex"] = False
    init_par["sharey"] = False
    init_par["subplot_hei"] = 4
    init_par["subplot_wid"] = 4
    init_par["gs"] = {"hspace": 0.4, "wspace": 0.4}

    fig, ax = plot_util.init_fig(4, **init_par)

    if title is not None:
        fig.suptitle(title, fontweight="bold", y=0.97)

    sess_ns = None

    # first pass to plot
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        sub_ax = ax[pl, li]

        if len(lp_df) != 1:
            raise RuntimeError("Expected exactly one row.")
        lp_row = lp_df.loc[lp_df.index[0]]

        if sess_ns is None:
            # axis labels are set from the first line/plane's session pair
            sess_ns = lp_row["sess_ns"]
            xlabel = f"Session {sess_ns[0]} USIs"
            if diffs:
                ylabel = f"Session {sess_ns[1]} - {sess_ns[0]} USIs"
            else:
                ylabel = f"Session {sess_ns[1]} USIs"
        elif sess_ns != lp_row["sess_ns"]:
            raise RuntimeError("Expected all sess_ns to match.")

        rand_density = np.asarray(lp_row["binned_rand_stats"]).T
        sub_ax.contour(
            lp_row["x_bin_mids"], lp_row["y_bin_mids"], rand_density,
            levels=6, cmap="Greys", zorder=-13, linewidths=4
        )

        # markers become more transparent as the number of points grows
        n_pts = len(lp_row["corr_data_xs"])
        alpha = 0.3**(n_pts / 300)
        sub_ax.scatter(
            lp_row["corr_data_xs"], lp_row["corr_data_ys"], color=col,
            alpha=alpha, lw=2, s=35
        )

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", xticks=None, ylab=ylabel, xlab=xlabel, kind="reg"
    )

    # second pass to add plot markings
    comp_info = misc_analys.get_comp_info(permpar)
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    sig_str = ""
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
        sub_ax = ax[pl, li]

        # add markers back in (removed due to kind='reg')
        sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)
        sub_ax.spines["bottom"].set_visible(True)

        lp_row = lp_df.loc[lp_df.index[0]]
        p_val_corr = lp_row["p_vals"]

        lp_sig_str = add_scatterplot_markers(
            sub_ax, permpar, lp_row["corrs"], lp_row["rand_corr_meds"],
            lp_row["regr_coefs"], lp_row["regr_intercepts"], p_val_corr,
            col=col, diffs=diffs
        )

        sig_str = (
            f"{sig_str}{TAB}{line_plane_name}: "
            f"{p_val_corr:.5f}{lp_sig_str:3}"
        )

    logger.info(sig_str, extra={"spacing": TAB})

    return ax
def add_between_sess_sig(ax, data_df, permpar, data_col="diff_stats",
                         highest=None, ctrl=False, p_val_prefix=False,
                         dry_run=False):
    """
    add_between_sess_sig(ax, data_df, permpar)

    Plots significance markers for significant comparisons between sessions.
    Two passes are run over the line/plane groups: the first logs p-values
    and extends the y limits to make room for the markers, and the second
    draws the markers.

    Required args:
        - ax (plt Axis):
            axis (indexed as ax[plane, line])
        - data_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - {data_col} (list): data stats (me, err)
            for session comparisons, e.g. 1v2:
            - p_vals_{}v{} (float): p-value for differences between sessions,
                corrected for multiple comparisons and tails
        - permpar (dict):
            dictionary with keys of PermPar namedtuple

    Optional args:
        - data_col (str):
            data column name in data_df
            default: "diff_stats"
        - highest (list):
            highest point for each line/plane, in order (updated in place)
            default: None
        - ctrl (bool):
            if True, significance symbols should use control colour and
            symbol
            default: False
        - p_val_prefix (bool):
            if True, p-value columns start with data_col as a prefix
            "{data_col}_p_vals_{}v{}".
            default: False
        - dry_run (bool):
            if True, a dry-run is done to get highest values, but nothing is
            plotted or logged.
            default: False

    Returns:
        - highest (list):
            highest point for each line/plane, in order, after plotting
    """

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    prefix = ""
    if p_val_prefix:
        prefix = f"{data_col}_"

    if not dry_run:
        logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    # marker formatting depends only on whether control data is marked
    mark_rel_y, fontsize = (0.22, 14) if ctrl else (0.18, 20)

    for pass_n in [0, 1]:  # add significance markers on the second pass
        linpla_grps = list(data_df.groupby(["lines", "planes"]))
        if highest is None:
            highest = [0] * len(linpla_grps)
        elif len(highest) != len(linpla_grps):
            raise ValueError("If highest is provided, it must contain as "
                "many values as line/plane groups in data_df.")

        for g, ((line, plane), lp_df) in enumerate(linpla_grps):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            line_plane_name = plot_helper_fcts.get_line_plane_name(line, plane)
            sub_ax = ax[pl, li]

            if ctrl:
                col = "gray"

            lp_sess_ns = np.sort(lp_df["sess_ns"].unique())
            for sess_n in lp_sess_ns:
                if len(lp_df.loc[lp_df["sess_ns"] == sess_n]) != 1:
                    raise RuntimeError("Expected 1 row per line/plane/session.")

            sig_p_vals, sig_strs, sig_xs = [], [], []
            lp_sig_str = f"{line_plane_name:6} (between sessions):"
            for i, sess_n1 in enumerate(lp_sess_ns):
                row_1s = lp_df.loc[lp_df["sess_ns"] == sess_n1]
                for sess_n2 in lp_sess_ns[i + 1:]:
                    row_2s = lp_df.loc[lp_df["sess_ns"] == sess_n2]
                    if not (len(row_1s) == 1 and len(row_2s) == 1):
                        raise RuntimeError(
                            "Expected exactly one row per session."
                        )
                    row1 = row_1s.loc[row_1s.index[0]]
                    row2 = row_2s.loc[row_2s.index[0]]

                    # top of each session's data: statistic + last error value
                    pair_tops = [
                        row[data_col][0] + row[data_col][-1]
                        for row in (row1, row2)
                    ]
                    highest[g] = np.nanmax([highest[g], *pair_tops])

                    if dry_run:
                        continue

                    p_val = row1[
                        f"{prefix}p_vals_{int(sess_n1)}v{int(sess_n2)}"
                    ]
                    side = np.sign(row2[data_col][0] - row1[data_col][0])

                    sig_str = misc_analys.get_sig_symbol(
                        p_val, sensitivity=sensitivity, side=side,
                        tails=permpar["tails"], p_thresh=permpar["p_val"],
                        ctrl=ctrl
                    )

                    if len(sig_str):
                        sig_p_vals.append(p_val)
                        sig_strs.append(sig_str)
                        sig_xs.append([sess_n1, sess_n2])

                    lp_sig_str = (
                        f"{lp_sig_str}{TAB} S{sess_n1}v{sess_n2}: "
                        f"{p_val:.5f}{sig_str:3}"
                    )

            if dry_run:
                continue

            n_sig = len(sig_p_vals)
            ylims = sub_ax.get_ylim()
            prop = np.diff(ylims)[0] / 8.0

            if pass_n == 0:
                logger.info(lp_sig_str, extra={"spacing": TAB})

            if n_sig == 0:
                continue

            if pass_n == 0:
                # count number of significant comparisons, and adjust y limits
                new_top = np.nanmax(
                    [ylims[1], highest[g] + prop * (n_sig + 1)]
                )
                sub_ax.set_ylim([ylims[0], new_top])
                continue

            # add significance markers sequentially, on second pass
            y_pos = highest[g]
            for s, (sig_str, sig_x) in enumerate(zip(sig_strs, sig_xs)):
                y_pos = highest[g] + (s + 1) * prop
                plot_util.plot_barplot_signif(
                    sub_ax, sig_x, y_pos, rel_y=0.11, color=col, lw=3,
                    mark_rel_y=mark_rel_y, mark=sig_str, fontsize=fontsize
                )
            highest[g] = np.nanmax([highest[g], y_pos])

            if y_pos > ylims[1]:
                sub_ax.set_ylim([ylims[0], y_pos * 1.1])

    return highest
def plot_idx_correlations(idx_corr_df, permpar, figpar, permute="sess",
                          corr_type="corr", title=None, small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the following
            columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index
                correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession
                ROI index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI
                index correlations
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations,
                corrected for multiple comparisons and tails
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title ("Correlations" is replaced with "Normalized
            correlations" if normalized correlations are plotted)
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array):
            array of subplots
    """

    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # title is optional, so only adjust it if one was provided
            if title is not None:
                title = title.replace(
                    "Correlations", "Normalized correlations"
                )
    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2)  # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="reg", n_sub=int(n_pairs / 2)
    )

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"
    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    comp_info = misc_analys.get_comp_info(permpar)
    sensitivity = misc_analys.get_sensitivity(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        # short line name, used as x tick label
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI: (low, median, high)
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(
                sub_ax, extr, med=CI[1], x=li, width=0.45, med_rat=0.025
            )

            y = row[f"{col_base}{norm_str}_corrs"]
            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(
                sub_ax, x=li, y=y, err=err, color=col, capsize=8
            )

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(
                p_val, sensitivity=sensitivity, side=side,
                tails=permpar["tails"], p_thresh=permpar["p_val"]
            )

            if len(sig_str):
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(
                    sub_ax, li, high, rel_y=0.1, color=col, fontsize=24,
                    mark=sig_str
                )

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", lines=["", ""], xticks=[0, 1], ylab="", xlab="",
        kind="traces"
    )

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False)
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            if row_n == 0:
                # n_pairs is rounded up to a multiple of 2, so the last
                # column may not correspond to a session pair
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    sub_ax.set_title(
                        sess_pair_title, fontweight="bold", y=1.07
                    )
                sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(
            ax[row_n], 3, axis="y", weight="bold", share=False,
            update_ticks=True
        )

    return ax
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False,
                       title=None):
    """
    plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar)

    Plots percentage of significant USIs.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:

            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig.
                ROIs
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for
                percent sig. ROIs
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig.
                ROIs, corrected for multiple comparisons and tails
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool):
            if True, plotting is done per mouse
            default: False
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots
    """

    perc_sig_df = perc_sig_df.copy(deep=True)

    nanpol = None if analyspar["rem_bad"] else "omit"

    sess_ns = perc_sig_df["sess_ns"].unique()
    if len(sess_ns) != 1:
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."
        )

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="idx", n_sub=1, sharex=True, sharey=True
    )

    figpar["init"]["sharey"] = True
    figpar["init"]["subplot_wid"] = 3.4
    figpar["init"]["gs"] = {"wspace": 0.18}
    if by_mouse:
        figpar["init"]["subplot_hei"] = 8.4
    else:
        figpar["init"]["subplot_hei"] = 3.5

    fig, ax = plot_util.init_fig(2, **figpar["init"])
    if title is not None:
        y = 0.98 if by_mouse else 1.07
        fig.suptitle(title, y=y, weight="bold")

    tail_order = ["Low tail", "High tail"]
    tail_keys = ["lo", "hi"]
    # percentage of ROIs expected in each tail, based on the p-value threshold
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    # x tick labels, stored by x position so that each label lines up with
    # the data plotted for its line/plane, even if a line/plane is missing
    linpla_names = [""] * n_linpla

    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for t, (tail, key) in enumerate(zip(tail_order, tail_keys)):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")
        sub_ax.set_ylim(ylims)

        # replace bottom spine with line at 0
        sub_ax.spines['bottom'].set_visible(False)
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        CIs = np.full((n_linpla, 2), np.nan)
        CI_meds = np.full(n_linpla, np.nan)

        tail_sig_str = f"{tail:9}:"
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            linpla_names[x_index] = linpla_name

            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns")  # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col,
                    interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                )
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                )
                # zero-height CI, so that only the mean/median line is shown
                CI_dummy = np.repeat(mouse_data_mean, 2)
                plot_util.plot_CI(
                    sub_ax, CI_dummy, med=mouse_data_mean, x=x_index,
                    width=0.6, med_col=col, med_rat=0.01
                )
            else:
                # collect confidence interval data
                row = lp_df.loc[df_indices[0]]
                mouse_cols = [col]
                CIs[x_index] = np.asarray(row[f"{data_key}_null_CIs"])[
                    np.asarray([0, 2])
                ]
                CI_meds[x_index] = row[f"{data_key}_null_CIs"][1]

            if by_mouse:
                perc_p_vals = []
                rel_y = 0.05
            else:
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                # plot UFOs
                err = None
                no_line = True
                if not by_mouse:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False
                # indicate bootstrapped error with wider capsize
                plot_util.plot_ufo(
                    sub_ax, x_index, perc_sig_df.loc[df_i, data_key], err,
                    color=mouse_col, capsize=8, no_line=no_line
                )

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                perc = perc_sig_df.loc[df_i, data_key]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                side = np.sign(perc - chance)
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                )
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side,
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

                if len(sig_str):
                    perc_high = perc + err if err is not None else perc
                    plot_util.add_signif_mark(
                        sub_ax, x_index, perc_high, rel_y=rel_y,
                        color=mouse_col, fontsize=24, mark=sig_str
                    )

                if by_mouse:
                    perc_p_vals.append(
                        (int(np.around(perc)), p_val, sig_str)
                    )
                else:
                    tail_sig_str = (
                        f"{tail_sig_str}{p_val:.5f}{sig_str:3}"
                    )

            if by_mouse:  # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
                        f"{p_val:.5f}{sig_str:3}"
                    )

        # add chance information
        if by_mouse:
            sub_ax.axhline(
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5,
                zorder=-12
            )
        else:
            plot_util.plot_CI(
                sub_ax, CIs.T, med=CI_meds, x=np.arange(n_linpla),
                width=0.45, med_rat=0.025, zorder=-12
            )

        logger.info(tail_sig_str, extra={"spacing": TAB})

    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False)
        plot_util.set_ticks(
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2
        )
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    ax[0, 0].set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if ax[0, 0].get_ylim()[0] < 0:
        ax[0, 0].set_ylim([ylims[0], ax[0, 0].get_ylim()[1]])

    return ax
def plot_roi_tracking(roi_mask_df, figpar, title=None):
    """
    plot_roi_tracking(roi_mask_df, figpar)

    Plots ROI tracking examples, for different session permutations, and union
    across permutations.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse (exactly one row is expected
            here), and the following columns, in addition to the basic
            sess_df columns:
            - "roi_mask_shapes" (list): shape into which ROI mask indices
                index (sess x hei x wid)
            - "union_n_conflicts" (int): number of conflicts after union
            for "union", "fewest" and "most" tracked ROIs:
            - "{}_registered_roi_mask_idxs" (list): list of mask indices,
                registered across sessions, for each session
                (flattened across ROIs) ((sess, hei, wid) x val),
                ordered by {}_sess_ns if "fewest" or "most"
            - "{}_n_tracked" (int): number of tracked ROIs
            for "fewest", "most" tracked ROIs:
            - "{}_sess_ns" (list): ordered session number
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries (not modified by this function)
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots

    Raises:
        - ValueError: if roi_mask_df does not contain exactly one row.
    """

    if len(roi_mask_df) != 1:
        raise ValueError("Expected only one row in roi_mask_df")
    roi_mask_row = roi_mask_df.loc[roi_mask_df.index[0]]

    # the empty column hosts the "union - conflicts" arrow and the legend
    columns = ["fewest", "most", "", "union"]

    # work on a copy of the initialization parameters, so that the caller's
    # figpar["init"] dictionary is not altered as a side effect
    init_kwargs = dict(figpar["init"])
    init_kwargs.update(
        ncols=len(columns),
        sharex=False,
        sharey=False,
        subplot_hei=5.05,
        subplot_wid=5.05,
        gs={"wspace": 0.06},
    )

    # MUST ADJUST if anything above changes [right, bottom, width, height]
    new_axis_coords = [0.905, 0.125, 0.06, 0.74]

    # NOTE(review): the number of subplots is taken from N_LINPLA, although
    # one subplot per entry of `columns` is indexed below - confirm that
    # the two values are meant to match.
    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **init_kwargs)

    # extra axis, used only to hold the vertical scale marker
    sub_ax_scale = fig.add_axes(new_axis_coords)
    plot_util.remove_axis_marks(sub_ax_scale)
    sub_ax_scale.spines["left"].set_visible(True)

    if title is not None:
        fig.suptitle(title, y=1.05, weight="bold")

    sess_cols = get_sess_cols(roi_mask_df)
    alpha = 0.6

    # label the first subplot with the line/plane name
    lp_col = plot_helper_fcts.get_line_plane_idxs(
        roi_mask_row["lines"], roi_mask_row["planes"]
    )[2]
    lp_name = plot_helper_fcts.get_line_plane_name(
        roi_mask_row["lines"], roi_mask_row["planes"]
    )
    ax[0, 0].set_ylabel(lp_name, fontweight="bold", color=lp_col)
    log_info = f"Conflicts and matches for a {lp_name} example:"

    for c, column in enumerate(columns):
        sub_ax = ax[0, c]

        if column == "":
            # spacer subplot: only carries the arrow text
            sub_ax.set_axis_off()
            subplot_title = \
                " Union - conflicts\n... ====================>"
            sub_ax.set_title(subplot_title, fontweight="bold", y=0.5)
            continue

        plot_util.remove_axis_marks(sub_ax)
        for spine in ["right", "left", "top", "bottom"]:
            sub_ax.spines[spine].set_visible(True)

        if column in ["fewest", "most"]:
            y = 1.01
            ord_sess_ns = roi_mask_row[f"{column}_sess_ns"]
            ord_sess_ns_str = ", ".join([str(n) for n in ord_sess_ns])

            n_matches = int(roi_mask_row[f"{column}_n_tracked"])
            subplot_title = f"{n_matches} matches\n(sess {ord_sess_ns_str})"
            log_info = (f"{log_info}\n{TAB}"
                f"{column.capitalize()} matches (sess {ord_sess_ns_str}): "
                f"{n_matches}")

        elif column == "union":
            y = 1.04
            ord_sess_ns = roi_mask_row["sess_ns"]
            n_union = int(roi_mask_row[f"{column}_n_tracked"])
            n_conflicts = int(roi_mask_row[f"{column}_n_conflicts"])
            n_matches = n_union - n_conflicts

            subplot_title = f"{n_matches} matches"
            log_info = (f"{log_info}\n{TAB}"
                "Union - conflicts: "
                f"{n_union} - {n_conflicts} = {n_matches} matches"
                )

        sub_ax.set_title(subplot_title, fontweight="bold", y=y)

        roi_masks = create_sess_roi_masks(
            roi_mask_row,
            mask_key=f"{column}_registered_roi_mask_idxs"
        )

        # masks are stored in the order given by ord_sess_ns, but are always
        # added in the order of the row's "sess_ns"
        for sess_n in roi_mask_row["sess_ns"]:
            col = sess_cols[int(sess_n)]
            s = ord_sess_ns.index(sess_n)
            add_roi_mask(sub_ax, roi_masks[s], col=col, alpha=alpha)

    # add scale marker
    hei_len = roi_mask_row["roi_mask_shapes"][1]
    add_scale_marker(
        sub_ax_scale, side_len=hei_len, ori="vertical", quadrant=3,
        fontsize=20
    )

    logger.info(log_info, extra={"spacing": "\n"})

    # add legend
    add_sess_col_leg(
        ax[0, columns.index("")],
        sess_cols,
        bbox_to_anchor=(0.67, 0.3),
        alpha=alpha
    )

    return ax
def plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar, title=None):
    """
    plot_roi_masks_overlayed_with_proj(roi_mask_df, figpar)

    Plots ROI mask contours overlayed over imaging planes, and ROI masks
    overlayed over each other across sessions.

    Required args:
        - roi_mask_df (pd.DataFrame in dict format):
            dataframe with a row for each mouse, and the following
            columns, in addition to the basic sess_df columns:
            - "max_projections" (list): pixel intensities of maximum
                projection for the plane (hei x wid)
            - "registered_roi_mask_idxs" (list): list of mask indices,
                registered across sessions, for each session
                (flattened across ROIs) ((sess, hei, wid) x val)
            - "roi_mask_idxs" (list): list of mask indices for each session,
                and each ROI (sess x (ROI, hei, wid) x val) (not registered)
            - "roi_mask_shapes" (list): shape into which ROI mask indices
                index (sess x hei x wid)
            - "crop_fact" (num): factor by which to crop masks (> 1)
            - "shift_prop_hei" (float): proportion by which to shift cropped
                mask center vertically from left edge [0, 1]
            - "shift_prop_wid" (float): proportion by which to shift cropped
                mask center horizontally from left edge [0, 1]
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots

    Raises:
        - RuntimeError: if a line/plane group does not contain exactly one
            row.
        - NotImplementedError: if fewer than 2 session colors are available.
    """

    n_lines = len(roi_mask_df["lines"].unique())
    n_planes = len(roi_mask_df["planes"].unique())

    # retrieved once, and reused for plotting and for the legend
    sess_cols = get_sess_cols(roi_mask_df)
    n_sess = len(sess_cols)
    n_cols = n_sess * n_lines

    figpar = sess_plot_util.fig_init_linpla(figpar)

    figpar["init"]["sharex"] = False
    figpar["init"]["sharey"] = False
    figpar["init"]["subplot_hei"] = 2.3
    figpar["init"]["subplot_wid"] = 2.3
    figpar["init"]["gs"] = {"wspace": 0.2, "hspace": 0.2}
    figpar["init"]["ncols"] = n_cols

    # 2 subplot rows are used per plane
    fig, ax = plot_util.init_fig(n_cols * n_planes * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.93, weight="bold")

    crop = "crop_fact" in roi_mask_df.columns

    alpha = 0.6
    raster_zorder = -12

    for (line, plane), lp_mask_df in roi_mask_df.groupby(["lines", "planes"]):
        li, pl, lp_col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        lp_name = plot_helper_fcts.get_line_plane_name(line, plane)

        if len(lp_mask_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        lp_row = lp_mask_df.loc[lp_mask_df.index[0]]

        # identify subplots
        # NOTE(review): the offsets are scaled by n_planes / n_lines, whereas
        # each plane spans 2 rows and n_cols is n_sess * n_lines - confirm
        # the intended layout when several lines or planes are included.
        base_row = (pl % n_planes) * n_planes
        base_col = (li % n_lines) * n_lines
        ax_grp = ax[base_row : base_row + 2, base_col : base_col + n_sess + 1]

        # add imaging planes and masks
        imaging_planes = add_proj_and_roi_masks(
            ax_grp, lp_row, sess_cols, crop=crop, alpha=alpha,
            proj_zorder=raster_zorder - 1
        )

        # add markings
        shared_row = base_row + 1
        shared_col = base_col + int((n_sess - 1) // 2)
        shared_sub_ax = ax[shared_row, shared_col]

        if shared_col == 0:
            shared_sub_ax.set_ylabel(lp_name, fontweight="bold", color=lp_col)
        else:
            # write the line/plane name inside the first subplot of the row
            lp_sub_ax = ax[shared_row, 0]
            lp_sub_ax.set_xlim([0, 1])
            lp_sub_ax.set_ylim([0, 1])
            lp_sub_ax.text(
                0.5, 0.5, lp_name, fontweight="bold", color=lp_col,
                ha="center", va="center", fontsize="x-large"
            )

        # add scale bar
        if n_sess < 2:
            raise NotImplementedError(
                "Scale bar placement not implemented for fewer than 2 "
                "sessions."
            )
        scale_ax = ax[shared_row, -1]
        wid_len = imaging_planes[0].shape[-1]
        add_scale_marker(
            scale_ax, side_len=wid_len, ori="horizontal", quadrant=1,
            fontsize=20
        )

    logger.info("Rasterizing imaging plane images...", extra={"spacing": TAB})
    for i, ax_row in enumerate(ax):
        for sub_ax in ax_row:
            plot_util.remove_axis_marks(sub_ax)
            # only even rows get a rasterization threshold
            if i % 2 == 0:
                sub_ax.set_rasterization_zorder(raster_zorder)

    # add legend
    if n_sess < 2:
        raise NotImplementedError(
            "Legend placement not implemented for fewer than 2 sessions."
        )
    add_sess_col_leg(
        ax[-1, -1], sess_cols, bbox_to_anchor=(1, 0.6), alpha=alpha,
        fontsize="small"
    )

    return ax
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar,
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - run_block_diffs (list):
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list):
                for pupil diameter differences per block
            - pupil_raw_p_vals (list):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (list):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter
            dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int):
            seed value to use. (-1 treated as None)
            default: None

    Returns:
        - ax (2D array):
            array of subplots

    Raises:
        - NotImplementedError: if analyspar["scale"] is set, or if block_df
            contains more than one session number.
        - RuntimeError: if a line/plane group does not contain exactly one
            row.
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
        )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    nanpol = None if analyspar["rem_bad"] else "omit"

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    # kept separate from the `title` argument, which is the figure title
    subplot_titles = {
        "run": "Running velocity (cm/s)",
        "pupil": "Pupil diameter (mm)",
    }
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False,
        gridspec_kw={"wspace": 0.22}
    )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, (datatype, datatype_str) in enumerate(
        zip(datatypes, datatype_strs)
    ):
        datatype_sig_str = f"{datatype_str:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
            )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
            )
            lp_names[int(x)] = line_plane_name

            # exactly one row is required: this also rejects empty groups,
            # which would otherwise reuse the index of a previous group
            if len(lp_df) != 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]

            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]

            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            # sign of the central statistic of the block differences
            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                )
            )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side,
                tails=permpar["tails"], p_thresh=permpar["p_val"]
            )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
            )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
        )

        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False)

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False,
            fontweight="bold", update_ticks=True
        )

        # p-value text is placed 8% of the y range above the upper limit
        # (computed as a plain scalar, and only once per subplot)
        y_low, y_high = sub_ax.get_ylim()
        text_y = y_high + (y_high - y_low) * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                # the very first text also carries the p-value type
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, text_y, p_val_text, fontsize=20, weight="bold", ha=ha
            )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
        )

        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")

        sub_ax.set_title(subplot_titles[datatype], weight="bold", y=1.2)

    return ax