def plot_ex_gabor_traces(ex_traces_df, stimpar, figpar, title=None):
    """
    plot_ex_gabor_traces(ex_traces_df, stimpar, figpar)

    Plots example Gabor traces.

    Required args:
        - ex_traces_df (pd.DataFrame):
            dataframe with a row for each ROI, and the following columns, 
            in addition to the basic sess_df columns: 
            - time_values (list): values for each frame, in seconds
            - roi_ns (list): selected ROI number
            - traces_sm (list): selected ROI sequence traces, smoothed, with 
                dims: seq x frames
            - trace_stat (list): selected ROI trace mean or median
        - stimpar (dict): 
            dictionary with keys of StimPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - ValueError: if stimpar["stimtype"] is not "gabors"
    """

    if stimpar["stimtype"] != "gabors":
        raise ValueError("Expected stimpar['stimtype'] to be 'gabors'.")

    # each line/plane gets a block of per_rows x per_cols subplots, sized to 
    # hold the largest number of example ROIs found for any line/plane
    group_columns = ["lines", "planes"]
    n_per = np.max(
        [len(lp_df) for _, lp_df in ex_traces_df.groupby(group_columns)]
    )
    per_rows, per_cols = math_util.get_near_square_divisors(n_per)
    n_per = per_rows * per_cols

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=per_rows
    )
    figpar["init"]["subplot_hei"] = 1.36
    figpar["init"]["subplot_wid"] = 2.47
    figpar["init"]["ncols"] = per_cols * 2

    fig, ax = plot_util.init_fig(
        plot_helper_fcts.N_LINPLA * n_per, **figpar["init"]
    )
    if title is not None:
        fig.suptitle(title, y=1.03, weight="bold")

    # y limits returned for each subplot (left as NaN where nothing is plotted)
    ylims = np.full(ax.shape + (2, ), np.nan)

    logger.info("Plotting individual traces...", extra={"spacing": TAB})
    raster_zorder = -12
    for (line, plane), lp_df in ex_traces_df.groupby(group_columns):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for i, idx in enumerate(lp_df.index):
            # fill the line/plane block column by column
            row_idx = int(pl * per_rows + i % per_rows)
            col_idx = int(li * per_cols + i // per_rows)
            sub_ax = ax[row_idx, col_idx]

            # traces are drawn below raster_zorder, so that they are 
            # rasterized at the end
            ylims[row_idx, col_idx] = plot_ex_gabor_roi_traces(
                sub_ax, 
                lp_df.loc[idx],
                col=col,
                dash=dash,
                zorder=raster_zorder - 1
            )

        time_values = np.asarray(lp_df.loc[lp_df.index[-1], "time_values"])

    sess_plot_util.format_linpla_subaxes(ax, fluor="dff", 
        area=False, datatype="roi", sess_ns=None, xticks=None, kind="traces", 
        modif_share=False)

    # fix x ticks and lims (identical for every subplot, so computed once)
    xlims = [time_values[0], time_values[-1]]
    xticks = np.linspace(*xlims, 6)
    for sub_ax in ax.reshape(-1):
        sub_ax.set_xticks(xticks)

    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold", skip=False)

    for sub_ax in ax.reshape(-1):
        sub_ax.set_xlim(xlims)
    
    # reset y limits, skipping subplots with no valid limits
    for r in range(ax.shape[0]):
        for c in range(ax.shape[1]):
            if not np.isfinite(ylims[r, c].sum()):
                continue
            ax[r, c].set_ylim(ylims[r, c])

    plot_util.set_interm_ticks(
        ax, 2, axis="y", share=False, weight="bold", update_ticks=True
    )

    # rasterize the gray lines
    logger.info("Rasterizing individual traces...", extra={"spacing": TAB})
    for sub_ax in ax.reshape(-1):
        sub_ax.set_rasterization_zorder(raster_zorder)

    return ax
def plot_stim_data_df(stim_data_df, stimpar, permpar, figpar, pop_stats=True, 
                      title=None):
    """
    plot_stim_data_df(stim_data_df, stimpar, permpar, figpar)

    Plots stimulus comparison data.

    Required args:
        - stim_data_df (pd.DataFrame):
            dataframe with one row per line/plane and one for all line/planes 
            together, and the basic sess_df columns, in addition to, 
            for each stimtype:
            - stimtype (list): absolute fractional change statistics (me, err)
            - raw_p_vals (float): uncorrected p-value for data differences 
                between stimulus types 
            - p_vals (float): p-value for data differences between stimulus 
                types, corrected for multiple comparisons and tails
        - stimpar (dict): 
            dictionary with keys of StimPar namedtuple 
            (the first two stimulus types in stimpar["stimtype"] are plotted)
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
    
    Optional args:
        - pop_stats (bool):
            if True, analyses are run on population statistics, and not 
            individual tracked ROIs
            default: True
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots 
            (does not include added subplot for all line/plane data together)
    """

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg")

    figpar["init"]["subplot_wid"] = 2.1
    figpar["init"]["subplot_hei"] = 4.2
    figpar["init"]["gs"] = {"hspace": 0.20, "wspace": 0.3}
    figpar["init"]["sharey"] = "row"
    
    fig, ax = plot_util.init_fig(plot_helper_fcts.N_LINPLA, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")

    # extra subplot, to the right, for all line/planes together
    sub_ax_all = fig.add_axes([1.05, 0.11, 0.3, 0.77])

    stimtypes = stimpar["stimtype"][:] # shallow copy

    # indicate bootstrapped error with wider capsize
    capsize = 8 if pop_stats else 6

    lp_data = []
    cols = []
    for (line, plane), lp_df in stim_data_df.groupby(["lines", "planes"]):
        x = [0, 1]
        # rows: statistics (me, err...), columns: stimulus types
        data = np.vstack(
            [lp_df[stimtypes[0]].tolist(), lp_df[stimtypes[1]].tolist()]
            ).T
        y = data[0]
        err = data[1:]

        if line != "all" and plane != "all":
            li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane
            )
            alpha = 0.5
            sub_ax = ax[pl, li]
            # retained to be plotted as dots in the "all" subplot
            lp_data.append(y)
            cols.append(col)
        else:
            col = plot_helper_fcts.NEARBLACK
            dash = None
            alpha = 0.2
            sub_ax = sub_ax_all
            sub_ax.set_title("all", fontweight="bold")
    
        plot_util.plot_bars(
            sub_ax, x, y=y, err=err, width=0.5, lw=None, alpha=alpha, 
            color=col, ls=dash, capsize=capsize
            )

    # add dots to the all subplot
    x_vals = np.asarray([-0.17, 0.25, -0.25, 0.17]) # to spread dots out
    lw = 4
    ms = 200
    for s, _ in enumerate(stimtypes):
        lp_stim_data = [lp_y[s] for lp_y in lp_data]
        # dots are offset in order of increasing value
        sorter = np.argsort(lp_stim_data)
        for i, idx in enumerate(sorter):
            x_val = s + x_vals[i]
            # white behind
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, linewidth=lw, alpha=0.8, 
                color="white", zorder=10
                )
            
            # colored dots
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, alpha=0.6, linewidth=0, 
                color=cols[idx], zorder=11
                )
            
            # dot borders
            sub_ax_all.scatter(
                x=x_val, y=lp_stim_data[idx], s=ms, color="None", 
                edgecolor=cols[idx], linewidth=lw, alpha=1, zorder=12
                )

    # add between stim significance 
    add_between_stim_sig(ax, sub_ax_all, stim_data_df, permpar)

    # add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", lines=None, 
        planes=["", ""], xticks=[0, 1], ylab="Absolute fractional change", 
        kind="reg", xlab=""
        )
    
    # adjust plot details (previously, these display names were computed but 
    # the raw stimulus type names were used as tick labels)
    stimtype_names = list(stimtypes)
    if "visflow" in stimtype_names:
        stimtype_names[stimtype_names.index("visflow")] = "visual\nflow"
    for sub_ax in fig.axes:
        y_max = sub_ax.get_ylim()[1]
        sub_ax.set_ylim([0, y_max])
        sub_ax.set_xticks([0, 1])
        sub_ax.set_xticklabels(
            stimtype_names, weight="bold", rotation=45, ha="right"
            )
        sub_ax.tick_params(axis="x", bottom=False)

    sub_ax_all.set_xlim(ax[0, 0].get_xlim())
    
    plot_util.set_interm_ticks(
        np.asarray(sub_ax_all), 4, axis="y", share=False, weight="bold"
        )

    return ax
def plot_corr_ex_data_scatterplot(sub_ax, idx_corr_norm_row, corr_name="1v2", 
                                  col="k"):
    """
    plot_corr_ex_data_scatterplot(sub_ax, idx_corr_norm_row)

    Plots example random correlation data in a scatterplot showing real versus
    example randomly generated data.

    Required args:
        - sub_ax (plt subplot): 
            subplot
        - idx_corr_norm_row (pd.Series):
            dataframe series with the following columns, in addition to the 
            basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_rand_ex_corrs (float): unnormalized intersession 
                ROI index correlations for an example of randomized data
            - {}v{}_corr_data (list): intersession values to correlate, 
                unpacked as (x values, y values)
            - {}v{}_rand_ex (list): intersession values for an example of 
                randomized data, unpacked as (x values, y values)

    Optional args:
        - corr_name (str):
            session pair correlation name (e.g., "1v2"), used to build the 
            series column names and the axis labels
            default: "1v2"
        - col (str):
            color for real data
            default: "k"
    """

    sess_pair = corr_name.split("v")
    sess_a, sess_b = sess_pair[0], sess_pair[1]

    # gather the example randomized data and the real data
    rand_x, rand_y = np.asarray(idx_corr_norm_row[f"{corr_name}_rand_ex"])
    rand_corr = idx_corr_norm_row[f"{corr_name}_rand_ex_corrs"]
    real_x, real_y = np.asarray(idx_corr_norm_row[f"{corr_name}_corr_data"])
    real_corr = idx_corr_norm_row[f"{corr_name}_corrs"]

    # randomized data is plotted first, so it appears first in the legend
    sub_ax.scatter(
        rand_x, rand_y, color="gray", alpha=0.4, marker="d", lw=2, 
        label=f"Raw random corr: {rand_corr:.2f}"
    )
    sub_ax.scatter(
        real_x, real_y, color=col, alpha=0.4, lw=2, 
        label=f"Raw corr: {real_corr:.2f}"
    )

    y_label = f"USI diff. between\nsession {sess_a} and {sess_b}"
    sub_ax.set_ylabel(y_label, fontweight="bold")
    sub_ax.set_xlabel(f"Session {sess_a} USIs", fontweight="bold")
    sub_ax.legend()

    plot_util.set_interm_ticks(
        np.asarray(sub_ax), n_ticks=4, axis="x", share=False, 
        fontweight="bold"
    )
def plot_sess_traces(data_df, analyspar, sesspar, figpar, 
                     trace_col="trace_stats", row_col="sess_ns", row_order=None, 
                     split="by_exp", title=None, size="reg"):
    """
    plot_sess_traces(data_df, analyspar, sesspar, figpar)

    Plots traces from dataframe.

    Required args:
        - data_df (pd.DataFrame):
            traces data frame with, in addition to the basic sess_df columns, 
            columns specified by trace_col, row_col, and a "time_values" column
        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - trace_col (str):
             dataframe column containing trace statistics, as 
             split x ROIs x frames x stats
             default: "trace_stats"
        - row_col (str):
            dataframe column specifying the variable that defines rows 
            within each line/plane
            default: "sess_ns"
        - row_order (list):
            ordered list specifying the order in which to plot from row_col.
            If None, automatic sorting order is used.
            default: None 
        - split (str):
            data split, e.g. "by_exp", "exp_lock", "unexp_lock", "stim_onset" 
            or "stim_offset". If not "by_exp", x limits are made symmetric 
            around 0.
            default: "by_exp"
        - title (str):
            plot title
            default: None
        - size (str):
            subplot sizes ("small", "wide" or "reg")
            default: "reg"

    Returns:
        - ax (2D array): 
            array of subplots
    """
    
    # retrieve session numbers, and infer row_order, if necessary
    sess_ns = None
    if row_col == "sess_ns":
        # NOTE(review): when row_order is inferred below, sess_ns remains 
        # None, so session numbers are only passed to the axis formatting if 
        # row_order was provided. Confirm this is intended.
        sess_ns = row_order
        if row_order is None:
            row_order = misc_analys.get_sess_ns(sesspar, data_df)

    elif row_order is None:
        row_order = data_df[row_col].unique()

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="traces", n_sub=len(row_order), sharey=False
        )

    if size == "small":
        figpar["init"]["subplot_hei"] = 1.51
        figpar["init"]["subplot_wid"] = 3.7
    elif size == "wide":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 4.8
        figpar["init"]["gs"] = {"wspace": 0.3, "hspace": 0.5}
    elif size == "reg":
        figpar["init"]["subplot_hei"] = 1.36
        figpar["init"]["subplot_wid"] = 3.4
    else:
        gen_util.accepted_values_error("size", size, ["small", "wide", "reg"])

    fig, ax = plot_util.init_fig(len(row_order) * 4, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=1.0, weight="bold")

    # time values of the last row plotted, used to set the x limits
    time_values = None
    for (line, plane), lp_df in data_df.groupby(["lines", "planes"]):
        li, pl, col, dash = plot_helper_fcts.get_line_plane_idxs(line, plane)

        for r, row_val in enumerate(row_order):
            rows = lp_df.loc[lp_df[row_col] == row_val]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected row_order instances to be unique per line/plane."
                    )
            row = rows.loc[rows.index[0]]

            sub_ax = ax[r + pl * len(row_order), li]

            if line == "L2/3-Cux2":
                exp_col = "darkgray" # oddly, lighter than gray
            else:
                exp_col = "gray"

            plot_traces(
                sub_ax, row["time_values"], row[trace_col], split=split, 
                col=col, ls=dash, exp_col=exp_col, lab=False
                )
            time_values = row["time_values"]
            
    for sub_ax in ax.reshape(-1):
        plot_util.set_minimal_ticks(sub_ax, axis="y")

    sess_plot_util.format_linpla_subaxes(ax, fluor=analyspar["fluor"], 
        area=False, datatype="roi", sess_ns=sess_ns, xticks=None, 
        kind="traces", modif_share=False)

    # fix x ticks and lims
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    if time_values is not None: # skipped if nothing was plotted
        xlims = [np.min(time_values), np.max(time_values)]
        if split != "by_exp":
            # make the limits symmetric around 0
            xlims = [-xlims[1], xlims[1]]
        # applied to every subplot, rather than only to the last one
        for sub_ax in ax.reshape(-1):
            sub_ax.set_xlim(xlims)

    return ax
def plot_idx_correlations(idx_corr_df, permpar, figpar, permute="sess", 
                          corr_type="corr", title=None, small=True):
    """
    plot_idx_correlations(idx_corr_df, permpar, figpar)

    Plots ROI USI index correlations across sessions.

    Required args:
        - idx_corr_df (pd.DataFrame):
            dataframe with one row per line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for session comparisons, e.g. 1v2
            - {}v{}{norm_str}_corrs (float): intersession ROI index correlations
            - {}v{}{norm_str}_corr_stds (float): bootstrapped intersession ROI 
                index correlation standard deviation
            - {}v{}_null_CIs (list): adjusted null CI for intersession ROI 
                index correlations
            - {}v{}_raw_p_vals (float): p-value for intersession correlations
            - {}v{}_p_vals (float): p-value for intersession correlations, 
                corrected for multiple comparisons and tails
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
    
    Optional args:
        - permute (str):
            type of permutation to do ("tracking", "sess" or "all")
            default: "sess"
        - corr_type (str):
            type of correlation run, i.e. "corr" or "R_sqr"
            default: "corr"
        - title (str):
            plot title (for normalized correlations, "Correlations" is 
            replaced with "Normalized correlations")
            default: None
        - small (bool):
            if True, smaller subplots are plotted
            default: True

    Returns:
        - ax (2D array): 
            array of subplots
    """

    norm = False
    if permute in ["sess", "all"]:
        corr_type = f"diff_{corr_type}"
        if corr_type == "diff_corr":
            norm = True
            # title is optional: previously, None raised an AttributeError
            if title is not None:
                title = title.replace(
                    "Correlations", "Normalized correlations"
                )

    norm_str = "_norm" if norm else ""

    sess_pairs = get_sorted_sess_pairs(idx_corr_df, norm=norm)
    n_pairs = int(np.ceil(len(sess_pairs) / 2) * 2) # multiple of 2

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="reg", 
        n_sub=int(n_pairs / 2))

    figpar["init"]["ncols"] = n_pairs
    figpar["init"]["sharey"] = "row"
    
    figpar["init"]["gs"] = {"hspace": 0.25}
    if small:
        figpar["init"]["subplot_wid"] = 2.7
        figpar["init"]["subplot_hei"] = 4.21
        figpar["init"]["gs"]["wspace"] = 0.2
    else:
        figpar["init"]["subplot_wid"] = 3.3
        figpar["init"]["subplot_hei"] = 4.71
        figpar["init"]["gs"]["wspace"] = 0.3

    fig, ax = plot_util.init_fig(n_pairs * 2, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=0.98, weight="bold")
    
    plane_pts = get_idx_corr_ylims(idx_corr_df, norm=norm)
    lines = [None, None]

    comp_info = misc_analys.get_comp_info(permpar)
    logger.info(f"{comp_info}:", extra={"spacing": "\n"})

    # depends only on permpar, so computed once
    sensitivity = misc_analys.get_sensitivity(permpar)

    for (line, plane), lp_df in idx_corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
        lines[li] = line.split("-")[0].replace("23", "2/3")

        if len(lp_df) != 1:
            raise RuntimeError("Expected only one row per line/plane.")
        row = lp_df.loc[lp_df.index[0]]

        lp_sig_str = f"{linpla_name:6}:"
        for s, sess_pair in enumerate(sess_pairs):
            sub_ax = ax[pl, s]
            if s == 0:
                sub_ax.set_ylim(plane_pts[pl])

            col_base = f"{sess_pair[0]}v{sess_pair[1]}"

            # null CI: (low, median, high)
            CI = row[f"{col_base}_null_CIs"]
            extr = np.asarray([CI[0], CI[2]])
            plot_util.plot_CI(sub_ax, extr, med=CI[1], x=li, width=0.45, 
                med_rat=0.025)

            y = row[f"{col_base}{norm_str}_corrs"]
            err = row[f"{col_base}{norm_str}_corr_stds"]
            plot_util.plot_ufo(
                sub_ax, x=li, y=y, err=err, color=col, capsize=8
                )

            # add significance markers
            p_val = row[f"{col_base}_p_vals"]
            side = np.sign(y - CI[1])
            sig_str = misc_analys.get_sig_symbol(
                p_val, sensitivity=sensitivity, side=side, 
                tails=permpar["tails"], p_thresh=permpar["p_val"]
                )

            if len(sig_str):
                high = np.max([CI[-1], y + err])
                plot_util.add_signif_mark(sub_ax, li, high, 
                    rel_y=0.1, color=col, fontsize=24, mark=sig_str) 

            sess_str = f"S{sess_pair[0]}v{sess_pair[1]}: "
            lp_sig_str = f"{lp_sig_str}{TAB}{sess_str}{p_val:.5f}{sig_str:3}"

        logger.info(lp_sig_str, extra={"spacing": TAB})

    # Add plane, line info to plots
    sess_plot_util.format_linpla_subaxes(ax, datatype="roi", 
        lines=["", ""], xticks=[0, 1], ylab="", xlab="", kind="traces")

    xs = np.arange(len(lines))
    pad_x = 0.6 * (xs[1] - xs[0])
    for row_n in range(len(ax)):
        for col_n in range(len(ax[row_n])):
            sub_ax = ax[row_n, col_n]
            sub_ax.tick_params(axis="x", which="both", bottom=False) 
            sub_ax.set_xticks(xs)
            sub_ax.set_xticklabels(lines, weight="bold")
            sub_ax.set_xlim([xs[0] - pad_x, xs[-1] + pad_x])

            sub_ax.set_ylim(plane_pts[row_n])
            sub_ax.set_yticks(plane_pts[row_n])

            if row_n == 0:
                if col_n < len(sess_pairs):
                    s1, s2 = sess_pairs[col_n]
                    sess_pair_title = f"Session {s1} v {s2}"
                    sub_ax.set_title(sess_pair_title, fontweight="bold", 
                        y=1.07)
                sub_ax.spines["bottom"].set_visible(True)

        plot_util.set_interm_ticks(
            ax[row_n], 3, axis="y", weight="bold", share=False, 
            update_ticks=True
            )

    return ax
def add_scatterplot_markers(sub_ax, permpar, regr_corr, rand_corr_med, slope, 
                            intercept, p_val, col="k", diffs=False):
    """
    add_scatterplot_markers(sub_ax, permpar, regr_corr, rand_corr_med, slope, 
                            intercept)

    Adds markers to a scatterplot (reference line, quadrant shading, 
    regression line, and slope and significance text).

    Required args:
        - sub_ax (plt subplot): 
            subplot
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - regr_corr (float):
            regression correlation
        - rand_corr_med (float):
            median of the random correlations
        - slope (float):
            regression slope
        - intercept (float):
            regression intercept
        - p_val (float):
            correlation p-value for significance marker

    Optional args:
        - col (str):
            regression line colour
            default: "k"
        - diffs (bool):
            if True, a horizontal line at 0 is plotted instead of the identity
            line, and the x and y limits are not matched or shaded
            default: False

    Returns:
        - sig_str (str):
            significance symbol obtained from misc_analys.get_sig_symbol()
    """

    if not diffs:
        # give both axes the same range, so the identity line is the diagonal
        (x_lo, x_hi), (y_lo, y_hi) = sub_ax.get_xlim(), sub_ax.get_ylim()
        shared_lims = (np.minimum(x_lo, y_lo), np.maximum(x_hi, y_hi))
        sub_ax.set_xlim(shared_lims)
        sub_ax.set_ylim(shared_lims)

    for axis_name in ("x", "y"):
        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=4, axis=axis_name, share=False, 
            fontweight="bold", update_ticks=True
        )
    
    # reference lines, drawn behind the data
    lo, hi = sub_ax.get_xlim()
    ref_kwargs = dict(
        ls=plot_helper_fcts.VDASH, color="k", alpha=0.2, zorder=-15, lw=4
    )
    if diffs:
        sub_ax.axhline(y=0, **ref_kwargs)
    else:
        # identity line
        sub_ax.plot([lo, hi], [lo, hi], **ref_kwargs)

        # shade the two quadrants where x and y have opposite signs
        shade_kwargs = dict(
            alpha=0.075, facecolor="k", edgecolor="none", zorder=-16
        )
        sub_ax.fill_between([lo, 0], 0, hi, **shade_kwargs)
        sub_ax.fill_between([0, hi], lo, 0, **shade_kwargs)

    # regression line, spanning the full x range
    x_line = np.array([lo, hi])
    sub_ax.plot(
        x_line, x_line * slope + intercept, ls=plot_helper_fcts.VDASH, 
        color=col, alpha=0.8, lw=ref_kwargs["lw"]
    )

    # significance symbol, with side given by real vs median random corr.
    sig_str = misc_analys.get_sig_symbol(
        p_val, 
        sensitivity=misc_analys.get_sensitivity(permpar), 
        side=np.sign(regr_corr - rand_corr_med), 
        tails=permpar["tails"], 
        p_thresh=permpar["p_val"]
    )

    # write slope in the bottom right corner (positions based on x limits)
    span = hi - lo
    sub_ax.text(
        lo + 0.97 * span, lo + 0.05 * span, f"{sig_str} slope: {slope:.2f}", 
        fontweight="bold", style="italic", fontsize=16, ha="right"
    )

    return sig_str
def plot_corr_ex_data_histogram(sub_ax, idx_corr_norm_row, corr_name="1v2", 
                                col="k"):
    """
    plot_corr_ex_data_histogram(sub_ax, idx_corr_norm_row)

    Plots example random correlation data in a histogram to show how 
    normalized residual correlations are calculated.

    Required args:
        - sub_ax (plt subplot): 
            subplot
        - idx_corr_norm_row (pd.Series):
            dataframe series with the following columns, in addition to the 
            basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_norm_corrs (float): normalized intersession ROI index 
                correlations
            - {}v{}_rand_corr_meds (float): median of randomized correlations
            - {}v{}_rand_corrs_binned (list): binned random unnormalized 
                intersession ROI index correlations
            - {}v{}_rand_corrs_bin_edges (list): first and last bin edges 
                (only the first 2 values are used)

    Optional args:
        - corr_name (str):
            session pair correlation name, used in series columns
            default: "1v2"
        - col (str):
            color for real data
            default: "k"
    """

    med = idx_corr_norm_row[f"{corr_name}_rand_corr_meds"]
    raw_corr = idx_corr_norm_row[f"{corr_name}_corrs"]
    norm_corr = idx_corr_norm_row[f"{corr_name}_norm_corrs"]

    binned_corrs = np.asarray(
        idx_corr_norm_row[f"{corr_name}_rand_corrs_binned"]
    )
    bin_edges = idx_corr_norm_row[f"{corr_name}_rand_corrs_bin_edges"]
    # rebuild evenly spaced bin edges from the first and last edges
    bin_edges = np.linspace(bin_edges[0], bin_edges[1], len(binned_corrs) + 1)

    # plot the precomputed counts: one sample per bin, weighted by its count
    sub_ax.hist(
        bin_edges[:-1], bin_edges, weights=binned_corrs, color="gray", 
        alpha=0.45, density=True
    )

    # median line
    sub_ax.axvline(x=med, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5)
    
    # corr line
    sub_ax.axvline(
        x=raw_corr, ls=plot_helper_fcts.VDASH, c=col, lw=3.0, alpha=0.7
    )

    # adjust axes so that at least 1/5 of the graph is beyond the correlation 
    # value (on the side away from the median), and so that the correlation 
    # bound (-1 or 1) is included. leave_space is a distance, so it is 
    # applied relative to raw_corr (previously, it was compared directly to 
    # axis positions, and thus never used).
    xlims = list(sub_ax.get_xlim())
    if raw_corr < med:
        leave_space = np.absolute(xlims[1] - raw_corr) / 3
        xlims[0] = np.min([xlims[0], -1.08, raw_corr - leave_space])
        edge = -1
    else:
        leave_space = np.absolute(raw_corr - xlims[0]) / 3
        xlims[1] = np.max([xlims[1], 1.08, raw_corr + leave_space])
        edge = 1

    # edge line
    sub_ax.axvline(x=edge, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5)

    # shift limits
    sub_ax.set_xlim(xlims)
    ylims = list(sub_ax.get_ylim())
    sub_ax.set_ylim(ylims[0], ylims[1] * 1.3)

    # set initial x ticks with more optimal interval
    n_ticks = 3
    xticks = plot_util.rounded_lims(xlims)
    for i, bound in zip([0, 1], [-1, 1]):
        if np.absolute(xticks[i]) >= 1:
            # anchor this end on the correlation bound, and place the other 
            # end so that the interval splits into round steps
            xticks[i] = bound
            step = np.diff(xticks)[0] / (n_ticks + 1)
            o = math_util.get_order_of_mag(step)
            new_step = np.ceil(step / 10**o) * 10**o
            xticks[1 - i] = xticks[i] - new_step * (n_ticks + 1) * bound
    sub_ax.set_xticks(xticks)

    plot_util.set_interm_ticks(
        np.asarray(sub_ax), n_ticks=n_ticks, axis="x", share=False, 
        fontweight="bold"
    )

    sub_ax.set_ylabel("Density", fontweight="bold", labelpad=10)
    sub_ax.set_xlabel("Raw correlations", fontweight="bold")
    sub_ax.set_title(
        f"Normalized residual\ncorrelation: {norm_corr:.2f}", 
        fontweight="bold", y=1.07
    )
def plot_rand_corr_ex_data(idx_corr_norm_df, title=None):
    """
    plot_rand_corr_ex_data(idx_corr_norm_df)

    Plots example random correlation data in a scatterplot and histogram to 
    show how normalized residual correlations are calculated.

    Required args:
        - idx_corr_norm_df (pd.DataFrame):
            dataframe with one row for a line/plane, and the 
            following columns, in addition to the basic sess_df columns:

            for a specific session comparison, e.g. 1v2
            - {}v{}_corrs (float): unnormalized intersession ROI index 
                correlations
            - {}v{}_norm_corrs (float): normalized intersession ROI index 
                correlations
            - {}v{}_rand_ex_corrs (float): unnormalized intersession 
                ROI index correlations for an example of randomized data
            - {}v{}_rand_corr_meds (float): median of randomized correlations

            - {}v{}_corr_data (list): intersession values to correlate
            - {}v{}_rand_ex (list): intersession values for an example of 
                randomized data
            - {}v{}_rand_corrs_binned (list): binned random unnormalized 
                intersession ROI index correlations
            - {}v{}_rand_corrs_bin_edges (list): bins edges

    Optional args:
        - title (str):
            plot title
            default: None

    Returns:
        - ax (1D array): 
            array of subplots (scatterplot, then histogram)

    Raises:
        - ValueError: if idx_corr_norm_df does not contain exactly one row
        - RuntimeError: if not exactly one session pair is found
    """

    # validate input before creating the figure, so that no figure is left 
    # open if an error is raised
    if len(idx_corr_norm_df) != 1:
        raise ValueError("Expected idx_corr_norm_df to contain only one row.")
    
    sorted_pairs = get_sorted_sess_pairs(idx_corr_norm_df, norm=True)
    if len(sorted_pairs) != 1:
        raise RuntimeError(
            "Expected to find only one pair of sessions for which to plot data."
            )

    plot_types = ["scatter", "hist"]
    fig, ax = plt.subplots(
        nrows=len(plot_types), figsize=[8.7, 9.3], gridspec_kw={"hspace": 0.7}
    )

    sess_pair = sorted_pairs[0]
    row = idx_corr_norm_df.loc[idx_corr_norm_df.index[0]]
    corr_name = f"{sess_pair[0]}v{sess_pair[1]}"
    _, _, col, _ = plot_helper_fcts.get_line_plane_idxs(
        row["lines"], row["planes"]
    )

    if title is not None:
        fig.suptitle(title, y=0.95, weight="bold")

    # plot scatterplot
    scatt_ax = ax[0]
    plot_corr_ex_data_scatterplot(scatt_ax, row, corr_name=corr_name, col=col)

    # plot histogram
    hist_ax = ax[1]
    plot_corr_ex_data_histogram(hist_ax, row, corr_name=corr_name, col=col)

    plot_util.set_interm_ticks(
        ax, n_ticks=4, axis="y", share=False, fontweight="bold"
    )

    return ax
def plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar, by_mouse=False, 
                       title=None):
    """
    plot_perc_sig_usis(perc_sig_df, analyspar, permpar, figpar)

    Plots percentage of significant USIs.

    Required args:
        - perc_sig_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following 
            columns, in addition to the basic sess_df columns:
            - "nrois" and "mouse_ns" are read (presumably among the basic 
                sess_df columns)
            for sig in ["lo", "hi"]: for low vs high ROI indices
            - perc_sig_{sig}_idxs (num): percent significant ROIs (0-100)
            - perc_sig_{sig}_idxs_stds (num): bootstrapped standard deviation 
                over percent significant ROIs
            - perc_sig_{sig}_idxs_CIs (list): adjusted CI for percent sig. ROIs 
            - perc_sig_{sig}_idxs_null_CIs (list): adjusted null CI for percent 
                sig. ROIs
            - perc_sig_{sig}_idxs_raw_p_vals (num): uncorrected p-value for 
                percent sig. ROIs
            - perc_sig_{sig}_idxs_p_vals (num): p-value for percent sig. 
                ROIs, corrected for multiple comparisons and tails
        - analyspar (dict): 
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict): 
            dictionary with keys of PermPar namedtuple
        - figpar (dict): 
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - by_mouse (bool): 
            if True, plotting is done per mouse
            default: False
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array): 
            array of subplots

    Raises:
        - NotImplementedError: if perc_sig_df contains more than one session
        - RuntimeError: if, when not plotting by mouse, a line/plane has more 
            than one row
    """  
 
    perc_sig_df = perc_sig_df.copy(deep=True)

    nanpol = None if analyspar["rem_bad"] else "omit"

    sess_ns = perc_sig_df["sess_ns"].unique()
    if len(sess_ns) != 1:
        raise NotImplementedError(
            "Plotting function implemented for 1 session only."
            )

    figpar = sess_plot_util.fig_init_linpla(figpar, kind="idx", n_sub=1, 
        sharex=True, sharey=True)

    figpar["init"]["sharey"] = True
    figpar["init"]["subplot_wid"] = 3.4
    figpar["init"]["gs"] = {"wspace": 0.18}
    if by_mouse:
        figpar["init"]["subplot_hei"] = 8.4
    else:
        figpar["init"]["subplot_hei"] = 3.5

    fig, ax = plot_util.init_fig(2, **figpar["init"])
    if title is not None:
        y = 0.98 if by_mouse else 1.07
        fig.suptitle(title, y=y, weight="bold")

    tail_order = ["Low tail", "High tail"]
    tail_keys = ["lo", "hi"]
    # percentage of ROIs expected to be significant in each tail by chance
    chance = permpar["p_val"] / 2 * 100

    ylims = get_perc_sig_ylims(perc_sig_df, high_pt_min=40)
    n_linpla = plot_helper_fcts.N_LINPLA

    comp_info = misc_analys.get_comp_info(permpar)

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    for t, (tail, key) in enumerate(zip(tail_order, tail_keys)):
        sub_ax = ax[0, t]
        sub_ax.set_title(tail, fontweight="bold")
        sub_ax.set_ylim(ylims)

        # replace bottom spine with line at 0
        sub_ax.spines['bottom'].set_visible(False)
        sub_ax.axhline(y=0, c="k", lw=4.0)

        data_key = f"perc_sig_{key}_idxs"

        CIs = np.full((plot_helper_fcts.N_LINPLA, 2), np.nan)
        CI_meds = np.full(plot_helper_fcts.N_LINPLA, np.nan)

        tail_sig_str = f"{tail:9}:"
        # tick labels are stored by x position (previously, they were stored 
        # in groupby order, which could mislabel ticks or mismatch the 
        # number of ticks if a line/plane was absent)
        linpla_names = [""] * n_linpla
        for (line, plane), lp_df in perc_sig_df.groupby(["lines", "planes"]):
            li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
            x_index = 2 * li + pl
            linpla_name = plot_helper_fcts.get_line_plane_name(line, plane)
            linpla_names[x_index] = linpla_name
            
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1 and not by_mouse:
                raise RuntimeError("Expected a single row per line/plane.")

            lp_df = lp_df.sort_values("mouse_ns") # sort by mouse
            df_indices = lp_df.index.tolist()

            if by_mouse:
                # plot means or medians per mouse
                mouse_data = lp_df[data_key].to_numpy()
                mouse_cols = plot_util.get_hex_color_range(
                    len(lp_df), col=col, 
                    interval=plot_helper_fcts.MOUSE_COL_INTERVAL
                    )
                mouse_data_mean = math_util.mean_med(
                    mouse_data, stats=analyspar["stats"], nanpol=nanpol
                    )
                # zero-width CI, drawing only the mean/median marker
                CI_dummy = np.repeat(mouse_data_mean, 2)
                plot_util.plot_CI(sub_ax, CI_dummy, med=mouse_data_mean, 
                    x=x_index, width=0.6, med_col=col, med_rat=0.01)
            else:
                # collect confidence interval data
                row = lp_df.loc[df_indices[0]]
                mouse_cols = [col]
                CIs[x_index] = np.asarray(row[f"{data_key}_null_CIs"])[
                    np.asarray([0, 2])
                    ]
                CI_meds[x_index] = row[f"{data_key}_null_CIs"][1]

            if by_mouse:
                perc_p_vals = []
                rel_y = 0.05
            else:
                tail_sig_str = f"{tail_sig_str}{TAB}{linpla_name}: "
                rel_y = 0.1

            for df_i, mouse_col in zip(df_indices, mouse_cols):
                # plot UFOs
                err = None
                no_line = True
                if not by_mouse:
                    err = perc_sig_df.loc[df_i, f"{data_key}_stds"]
                    no_line = False
                # indicate bootstrapped error with wider capsize
                plot_util.plot_ufo(
                    sub_ax, x_index, perc_sig_df.loc[df_i, data_key], err,
                    color=mouse_col, capsize=8, no_line=no_line
                    )

                # add significance markers
                p_val = perc_sig_df.loc[df_i, f"{data_key}_p_vals"]
                perc = perc_sig_df.loc[df_i, data_key]
                nrois = np.sum(perc_sig_df.loc[df_i, "nrois"])
                side = np.sign(perc - chance)
                sensitivity = misc_analys.get_binom_sensitivity(
                    nrois, null_perc=chance, side=side
                    )
                sig_str = misc_analys.get_sig_symbol(
                    p_val, sensitivity=sensitivity, side=side, 
                    tails=permpar["tails"], p_thresh=permpar["p_val"]
                    )

                if len(sig_str):
                    perc_high = perc + err if err is not None else perc
                    plot_util.add_signif_mark(sub_ax, x_index, perc_high, 
                        rel_y=rel_y, color=mouse_col, fontsize=24, 
                        mark=sig_str) 

                if by_mouse:
                    perc_p_vals.append(
                        (int(np.around(perc)), p_val, sig_str)
                    )
                else:
                    tail_sig_str = (
                        f"{tail_sig_str}{p_val:.5f}{sig_str:3}"
                        )

            if by_mouse: # sort p-value logging by percentage value
                tail_sig_str = f"{tail_sig_str}\n\t{linpla_name:6}: "
                order = np.argsort([vals[0] for vals in perc_p_vals])
                for i in order:
                    perc, p_val, sig_str = perc_p_vals[i]
                    perc_str = f"(~{perc}%)"
                    tail_sig_str = (
                        f"{tail_sig_str}{TAB}{perc_str:6} "
                        f"{p_val:.5f}{sig_str:3}"
                        )

        # add chance information
        if by_mouse:
            sub_ax.axhline(
                y=chance, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5, 
                zorder=-12
                )
        else:
            plot_util.plot_CI(sub_ax, CIs.T, med=CI_meds, 
                x=np.arange(n_linpla), width=0.45, med_rat=0.025, zorder=-12)

        logger.info(tail_sig_str, extra={"spacing": TAB})
    
    for sub_ax in fig.axes:
        sub_ax.tick_params(axis="x", which="both", bottom=False) 
        plot_util.set_ticks(
            sub_ax, min_tick=0, max_tick=n_linpla - 1, n=n_linpla, pad_p=0.2)
        sub_ax.set_xticklabels(linpla_names, rotation=90, weight="bold")

    ax[0, 0].set_ylabel("%", fontweight="bold")
    plot_util.set_interm_ticks(ax, 3, axis="y", weight="bold", share=True)

    # adjustment if tick interval is repeated in the negative
    if ax[0, 0].get_ylim()[0] < 0:
        ax[0, 0].set_ylim([ylims[0], ax[0, 0].get_ylim()[1]])

    return ax
def plot_idxs(idx_df, sesspar, figpar, plot="items", density=True, n_bins=40,
              title=None, size="reg"):
    """
    plot_idxs(idx_df, sesspar, figpar)

    Plots histograms of ROI stimulus indices (or of their percentiles), with
    one subplot per session and line/plane combination.

    Required args:
        - idx_df (pd.DataFrame):
            dataframe with indices for different line/plane combinations, and
            the following columns, in addition to the basic sess_df columns:
            - rand_idx_binned (list): bin counts for the random ROI indices
            - bin_edges (list): first and last bin edge
            - CI_edges (list): confidence interval limit values
            - CI_perc (list): confidence interval percentile limits
            if plot == "items":
            - roi_idx_binned (list): bin counts for the ROI indices
            if plot == "percs":
            - perc_idx_binned (list): bin counts for the ROI index percentiles
            optionally:
            - n_signif_lo (int): number of significant ROIs (low)
            - n_signif_hi (int): number of significant ROIs (high)
        - sesspar (dict):
            dictionary with keys of SessPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
            Note: figpar["init"]["sharey"] is set in place to "row".

    Optional args:
        - plot (str):
            type of data to plot ("items" or "percs")
            default: "items"
        - density (bool):
            if True, histograms are plotted as densities
            default: True
        - n_bins (int):
            number of bins to use in histograms
            default: 40
        - title (str):
            plot title
            default: None
        - size (str):
            plot size ("reg", "small" or "tall")
            default: "reg"

    Returns:
        - ax (2D array):
            array of subplots (rows: sessions x planes, columns: lines)
    """

    if plot == "items":
        data_key = "roi_idx_binned"
        CI_key = "CI_edges"
    elif plot == "percs":
        data_key = "perc_idx_binned"
        CI_key = "CI_perc"
    else:
        gen_util.accepted_values_error("plot", plot, ["items", "percs"])

    sess_ns = misc_analys.get_sess_ns(sesspar, idx_df)

    # one subplot per session and per line/plane combination (4 in total)
    n_plots = len(sess_ns) * 4
    figpar["init"]["sharey"] = "row"
    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="idx", n_sub=len(sess_ns), sharex=(plot == "percs")
    )

    # suptitle height and subplot dimensions depend on the requested size
    y = 1
    if size == "reg":
        subplot_hei = 3.2
        subplot_wid = 5.5
    elif size == "small":
        y = 1.04
        subplot_hei = 2.40
        subplot_wid = 3.75
        figpar["init"]["gs"] = {"hspace": 0.25, "wspace": 0.30}
    elif size == "tall":
        y = 0.98
        subplot_hei = 5.3
        subplot_wid = 5.55
    else:
        gen_util.accepted_values_error("size", size, ["reg", "small", "tall"])

    figpar["init"]["subplot_hei"] = subplot_hei
    figpar["init"]["subplot_wid"] = subplot_wid
    # re-applied after fig_init_linpla(), in case it changed sharey
    figpar["init"]["sharey"] = "row"

    fig, ax = plot_util.init_fig(n_plots, **figpar["init"])
    if title is not None:
        fig.suptitle(title, y=y, weight="bold")

    for (line, plane), lp_df in idx_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for s, sess_n in enumerate(sess_ns):
            rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            if len(rows) == 0:
                continue
            elif len(rows) > 1:
                raise RuntimeError(
                    "Expected sess_ns to be unique per line/plane."
                )
            row = rows.loc[rows.index[0]]

            # rows are grouped by plane (one row per session), columns by line
            sub_ax = ax[s + pl * len(sess_ns), li]

            # get percentage significant label, if significance counts exist
            perc_label = None
            if "n_signif_lo" in row.keys() and "n_signif_hi" in row.keys():
                n_sig_lo, n_sig_hi = row["n_signif_lo"], row["n_signif_hi"]
                nrois = np.sum(row["nrois"])
                perc_signif = np.sum([n_sig_lo, n_sig_hi]) / nrois * 100
                perc_label = (
                    f"{perc_signif:.2f}% sig\n"
                    f"({n_sig_lo}-/{n_sig_hi}+ of {nrois})"
                )

            plot_stim_idx_hist(
                sub_ax, row[data_key], row[CI_key], n_bins=n_bins,
                rand_data=row["rand_idx_binned"],
                orig_edges=row["bin_edges"], plot=plot, col=col,
                density=density, perc_label=perc_label
            )

            if size == "small":
                sub_ax.legend(fontsize="small")

    # Add plane, line info to plots
    y_lab = "Density" if density else "N ROIs"
    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", ylab=y_lab, xticks=None, sess_ns=None,
        kind="idx", modif_share=False, xlab="Index", single_lab=True
    )

    # Add indices after setting formatting
    if plot == "percs":
        nticks = 5
        xticks = [int(np.around(x, 0)) for x in np.linspace(0, 100, nticks)]
        for sub_ax in ax[-1]:
            sub_ax.set_xticks(xticks)
            sub_ax.set_xticklabels(xticks, weight="bold")
    elif plot == "items":
        nticks = 3
        plot_util.set_interm_ticks(
            ax, nticks, axis="x", weight="bold", share=False, skip=False
        )
    else:
        gen_util.accepted_values_error("plot", plot, ["items", "percs"])

    return ax
def plot_roi_correlations(corr_df, figpar, title=None, log_scale=True):
    """
    plot_roi_correlations(corr_df, figpar)

    Plots a histogram of pre-binned ROI correlation values for each session
    and line/plane combination.

    Required args:
        - corr_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - bin_edges (list): first and last bin edge
            - corrs_binned (list): number of correlation values per bin
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - log_scale (bool):
            if True, a near logarithmic scale is used for the y axis (with a
            linear range to reach 0, and break marks to mark the transition
            from linear to log range). In that case, subplots share x and y
            axes, and x limits are fixed to [-1, 1].
            default: True

    Returns:
        - ax (2D array):
            array of subplots (rows: planes, column groups: lines x sessions)
    """

    # every session number from the lowest to the highest gets a column
    first_sess, last_sess = corr_df.sess_ns.min(), corr_df.sess_ns.max()
    sess_ns = np.arange(first_sess, last_sess + 1)
    n_sess = len(sess_ns)

    figpar = sess_plot_util.fig_init_linpla(
        figpar, kind="prog", n_sub=n_sess
    )
    init_kwargs = figpar["init"]
    init_kwargs["subplot_hei"] = 3.0
    init_kwargs["subplot_wid"] = 2.8
    init_kwargs["sharex"] = log_scale
    if log_scale:
        init_kwargs["sharey"] = True

    fig, ax = plot_util.init_fig(4 * n_sess, **init_kwargs)
    if title is not None:
        fig.suptitle(title, y=1.02, weight="bold")

    sess_plot_util.format_linpla_subaxes(
        ax, datatype="roi", ylab="Density", xlab="Correlation",
        sess_ns=sess_ns, kind="prog", single_lab=True
    )

    log_base = 2
    for (line, plane), lp_df in corr_df.groupby(["lines", "planes"]):
        li, pl, col, _ = plot_helper_fcts.get_line_plane_idxs(line, plane)
        for s, sess_n in enumerate(sess_ns):
            sess_rows = lp_df.loc[lp_df["sess_ns"] == sess_n]
            n_rows = len(sess_rows)
            if n_rows > 1:
                raise RuntimeError("Expected exactly one row.")
            if n_rows == 0:
                continue
            sess_row = sess_rows.iloc[0]

            # planes are rows; each line occupies a group of n_sess columns
            sub_ax = ax[pl, s + li * n_sess]

            counts = np.asarray(sess_row["corrs_binned"])
            edges = np.linspace(*sess_row["bin_edges"], len(counts) + 1)
            # counts are already binned: one sample per bin, weighted by count
            sub_ax.hist(
                edges[:-1], edges, weights=counts, color=col, alpha=0.6,
                density=True
            )
            sub_ax.axvline(
                0, ls=plot_helper_fcts.VDASH, c="k", lw=3.0, alpha=0.5
            )

            sub_ax.spines["bottom"].set_visible(True)
            sub_ax.tick_params(axis="x", which="both", bottom=True, top=False)

            if not log_scale:
                sub_ax.autoscale(axis="x", tight=True)
            else:
                sub_ax.set_yscale("log", base=log_base)
                sub_ax.set_xlim(-1, 1)
            sub_ax.autoscale(axis="y", tight=True)

    if log_scale:
        # apply the near-logarithmic y axis, per group of n_sess columns
        set_symlog_scale(ax, log_base=log_base, col_per_grp=n_sess, n_ticks=4)
    else:
        # share y ticks within each row's column group (one group per line)
        n_grps = ax.shape[1] // n_sess
        for r in range(ax.shape[0]):
            for g in range(n_grps):
                grp_axes = ax[r, g * n_sess : (g + 1) * n_sess]
                plot_util.set_interm_ticks(
                    grp_axes, 4, axis="y", share=True, update_ticks=True
                )

    plot_util.set_interm_ticks(
        ax, 4, axis="x", share=log_scale, update_ticks=True, fontweight="bold"
    )

    return ax
def plot_pupil_run_trace_stats(trace_df, analyspar, figpar, split="by_exp",
                               title=None):
    """
    plot_pupil_run_trace_stats(trace_df, analyspar, figpar)

    Plots pupil and running trace statistics, one subplot per data type
    (running velocity on top, pupil diameter below), sharing the x axis.

    Required args:
        - trace_df (pd.DataFrame):
            dataframe with one row per session number, and the following
            columns, in addition to the basic sess_df columns:
            - run_trace_stats (list):
                running velocity trace stats (split x frames x stats (me, err))
            - run_time_values (list):
                values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")
            - pupil_trace_stats (list):
                pupil diameter trace stats (split x frames x stats (me, err))
            - pupil_time_values (list):
                values for each frame, in seconds
                (only 0 to stimpar.post, unless split is "by_exp")
            Currently, only a dataframe with a single row is supported.
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - split (str):
            data split, e.g. "by_exp", "exp_lock", "unexp_lock", "stim_onset"
            or "stim_offset". Currently, only "by_exp" is implemented.
            default: "by_exp"
        - title (str):
            plot title
            default: None

    Returns:
        - ax (2D array):
            array of subplots

    Raises:
        - NotImplementedError:
            if split is not "by_exp", if analyspar["scale"] is set, or if
            trace_df does not have exactly one row
    """

    if split != "by_exp":
        raise NotImplementedError("Only implemented split 'by_exp'.")

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
        )

    # validate before touching figpar or creating a figure, so that a failure
    # does not leave an empty figure open
    if len(trace_df) != 1:
        raise NotImplementedError(
            "Only implemented for a trace_df with one row."
        )

    datatypes = ["run", "pupil"]

    figpar["init"]["subplot_wid"] = 4.2
    figpar["init"]["subplot_hei"] = 2.2
    figpar["init"]["gs"] = {"hspace": 0.3}
    figpar["init"]["ncols"] = 1
    figpar["init"]["sharey"] = False
    figpar["init"]["sharex"] = True
    fig, ax = plot_util.init_fig(len(datatypes), **figpar["init"])

    if title is not None:
        fig.suptitle(title, weight="bold", y=1.0)

    row_idx = trace_df.index[0]

    exp_col = plot_util.LINCLAB_COLS["gray"]
    unexp_col = plot_util.LINCLAB_COLS["red"]
    ylabels = {
        "run": "Running\nvelocity\n(cm/s)",
        "pupil": "Pupil\ndiameter\n(mm)",
    }
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[d, 0]

        time_values = trace_df.loc[row_idx, f"{datatype}_time_values"]
        trace_stats = trace_df.loc[row_idx, f"{datatype}_trace_stats"]

        seq_plots.plot_traces(
            sub_ax, time_values, trace_stats, split=split, col=unexp_col,
            lab=False, exp_col=exp_col, hline=False
        )

        sub_ax.set_ylabel(ylabels[datatype], weight="bold")

    # fix x ticks and lims: the x axis is shared, so limits are set once on
    # the bottom subplot, using the last data type's time values
    plot_util.set_interm_ticks(ax, 3, axis="x", fontweight="bold")
    xlims = [np.min(time_values), np.max(time_values)]
    if split != "by_exp":
        xlims = [-xlims[1], xlims[1]]
    sub_ax.set_xlim(xlims)
    sub_ax.set_xlabel("Time (s)", weight="bold")

    # expand y lims a bit and fix y ticks
    for sub_ax in ax.reshape(-1):
        plot_util.expand_lims(sub_ax, axis="y", prop=0.21)

    plot_util.set_interm_ticks(
        ax, 2, axis="y", share=False, weight="bold", update_ticks=True
    )

    return ax
def plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar,
                               title=None, seed=None):
    """
    plot_pupil_run_block_diffs(block_df, analyspar, permpar, figpar)

    Plots pupil and running block differences as violin plots, one subplot
    per data type, with each line/plane's p-value annotated above its data.

    Required args:
        - block_df (pd.DataFrame):
            dataframe with one row per session/line/plane, and the following
            columns, in addition to the basic sess_df columns:
            - run_block_diffs (list):
                running velocity differences per block
            - run_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - run_p_vals (float):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails
            - pupil_block_diffs (list):
                pupil diameter differences per block
            - pupil_raw_p_vals (float):
                uncorrected p-value for differences within sessions
            - pupil_p_vals (float):
                p-value for differences within sessions,
                corrected for multiple comparisons and tails
            All rows must share a single session number.
        - analyspar (dict):
            dictionary with keys of AnalysPar namedtuple
        - permpar (dict):
            dictionary with keys of PermPar namedtuple
        - figpar (dict):
            dictionary containing the following figure parameter dictionaries
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters

    Optional args:
        - title (str):
            plot title
            default: None
        - seed (int):
            seed value to use, passed to plot_violin_data()
            (-1 treated as None)
            default: None

    Returns:
        - ax (2D array):
            array of subplots

    Raises:
        - NotImplementedError:
            if analyspar["scale"] is set, or if block_df contains more than
            one session number
        - RuntimeError:
            if a line/plane has more than one row
    """

    if analyspar["scale"]:
        raise NotImplementedError(
            "Expected running and pupil data to not be scaled."
        )

    if len(block_df["sess_ns"].unique()) != 1:
        raise NotImplementedError(
            "'block_df' should only contain one session number."
        )

    nanpol = None if analyspar["rem_bad"] else "omit"

    sensitivity = misc_analys.get_sensitivity(permpar)
    comp_info = misc_analys.get_comp_info(permpar)

    datatypes = ["run", "pupil"]
    datatype_strs = ["Running velocity", "Pupil diameter"]
    n_datatypes = len(datatypes)

    fig, ax = plt.subplots(
        1, n_datatypes, figsize=(12.7, 4), squeeze=False,
        gridspec_kw={"wspace": 0.22}
    )

    if title is not None:
        fig.suptitle(title, y=1.2, weight="bold")

    logger.info(f"{comp_info}:", extra={"spacing": "\n"})
    corr_str = "corr." if permpar["multcomp"] else "raw"

    for d, datatype in enumerate(datatypes):
        datatype_sig_str = f"{datatype_strs[d]:16}:"
        sub_ax = ax[0, d]

        lp_names = [None for _ in range(plot_helper_fcts.N_LINPLA)]
        xs, all_data, cols, dashes, p_val_texts = [], [], [], [], []
        for (line, plane), lp_df in block_df.groupby(["lines", "planes"]):
            x, col, dash = plot_helper_fcts.get_line_plane_idxs(
                line, plane, flat=True
            )
            line_plane_name = plot_helper_fcts.get_line_plane_name(
                line, plane
            )
            lp_names[int(x)] = line_plane_name

            # skip empty groups (e.g., unobserved categories) instead of
            # reusing the row index of a previous group
            if len(lp_df) == 0:
                continue
            elif len(lp_df) > 1:
                raise RuntimeError("Expected 1 row per line/plane/session.")
            row_idx = lp_df.index[0]

            lp_data = lp_df.loc[row_idx, f"{datatype}_block_diffs"]

            # get p-value information
            p_val_corr = lp_df.loc[row_idx, f"{datatype}_p_vals"]

            # direction of the effect, for one- vs two-tailed significance
            side = np.sign(
                math_util.mean_med(
                    lp_data, stats=analyspar["stats"], nanpol=nanpol
                )
            )
            sig_str = misc_analys.get_sig_symbol(
                p_val_corr, sensitivity=sensitivity, side=side,
                tails=permpar["tails"], p_thresh=permpar["p_val"]
            )

            p_val_text = f"{p_val_corr:.2f}{sig_str}"

            datatype_sig_str = (
                f"{datatype_sig_str}{TAB}{line_plane_name}: "
                f"{p_val_corr:.5f}{sig_str:3}"
            )

            # collect information
            xs.append(x)
            all_data.append(lp_data)
            cols.append(col)
            dashes.append(dash)
            p_val_texts.append(p_val_text)

        plot_violin_data(
            sub_ax, xs, all_data, palette=cols, dashes=dashes, seed=seed
        )

        # edit ticks
        sub_ax.set_xticks(range(plot_helper_fcts.N_LINPLA))
        sub_ax.set_xticklabels(lp_names, fontweight="bold")
        sub_ax.tick_params(axis="x", which="both", bottom=False)

        plot_util.expand_lims(sub_ax, axis="y", prop=0.1)

        plot_util.set_interm_ticks(
            np.asarray(sub_ax), n_ticks=3, axis="y", share=False,
            fontweight="bold", update_ticks=True
        )

        # p-value text goes 8% of the y range above the top of the axis;
        # computed once as a scalar (adding text does not change y limits)
        ylim_min, ylim_max = sub_ax.get_ylim()
        y = ylim_max + (ylim_max - ylim_min) * 0.08
        for i, (x, p_val_text) in enumerate(zip(xs, p_val_texts)):
            ha = "center"
            if d == 0 and i == 0:
                # the first annotation also carries the p-value type label
                x += 0.2
                p_val_text = f"{corr_str} p-val. {p_val_text}"
                ha = "right"
            sub_ax.text(
                x, y, p_val_text, fontsize=20, weight="bold", ha=ha
            )

        logger.info(datatype_sig_str, extra={"spacing": TAB})

    # add labels/titles
    for d, datatype in enumerate(datatypes):
        sub_ax = ax[0, d]
        sub_ax.axhline(
            y=0, ls=plot_helper_fcts.HDASH, c="k", lw=3.0, alpha=0.5
        )

        if d == 0:
            ylabel = "Trial differences\nU-G - D-G"
            sub_ax.set_ylabel(ylabel, weight="bold")

        # per-axis title (distinct name, so the `title` argument is not
        # overwritten)
        if datatype == "run":
            ax_title = "Running velocity (cm/s)"
        elif datatype == "pupil":
            ax_title = "Pupil diameter (mm)"

        sub_ax.set_title(ax_title, weight="bold", y=1.2)

    return ax