def plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar,
                  sess_info, autocorr_data, figpar=None, savedir=None):
    """
    plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar,
                  sess_info, autocorr_data)

    From dictionaries, plots autocorrelation during stimulus blocks.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis
                                parameters
            ["analysis"] (str): analysis type (e.g., "a")
            ["datatype"] (str): datatype ("run" or "roi")
        - autocorrpar (dict)  : dictionary with keys of AutocorrPar
                                namedtuple
        - sess_info (dict)    : dictionary containing information from each
                                session
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - autocorr_data (dict): dictionary containing data to plot:
            ["xrans"] (list): list of lag values in seconds for each session
            ["stats"] (list): list of 3D arrays (or nested lists) of
                              autocorrelation statistics, structured as:
                                     sessions stats (me, err)
                                     x ROI or 1x and 10x lag
                                     x lag

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None

    Returns:
        - fulldir (str) : final path of the directory in which the figure is
                          saved (may differ from input savedir, if datetime
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if extrapar["datatype"] is neither "roi" nor "run"
    """

    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"],
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    datatype = extrapar["datatype"]
    if datatype == "roi":
        fluorstr_pr = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        if autocorrpar["byitem"]:
            title_str = u"{}\nautocorrelation".format(fluorstr_pr)
        else:
            # the fluorescence label needs a placeholder to be included
            title_str = u"{}\nautocorr. acr. ROIs".format(fluorstr_pr)
    elif datatype == "run":
        datastr = sess_str_util.datatype_par_str(datatype)
        title_str = u"\n{} autocorrelation".format(datastr)
    else:
        # fail early, instead of with a NameError on title_str further down
        raise ValueError(
            f"datatype must be 'roi' or 'run', but found {datatype}.")

    if stimpar["stimtype"] == "gabors":
        seq_bars = [-1.5, 1.5] # light lines
    else:
        seq_bars = [-1.0, 1.0] # light lines

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi"))

    n_sess = len(mouse_ns)

    xrans = autocorr_data["xrans"]
    stats = [np.asarray(stat) for stat in autocorr_data["stats"]]

    lag_s = autocorrpar["lag_s"]
    # np.linspace requires an integer number of samples, so cast explicitly
    # (lag_s may be stored as a float, e.g. 4.0)
    xticks = np.linspace(-lag_s, lag_s, int(lag_s * 2 + 1))
    yticks = np.linspace(0, 1, 6)

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    byitemstr = "_byroi" if autocorrpar["byitem"] else ""

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])

    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
            u"{} ".format(statstr_pr) + f"{title_str} (sess "
            f"{sess_ns[i]}, {lines[i]} {planes[i]}{dendstr_pr}{nroi_strs[i]})")
        # transpose to ROI/lag x stats x series
        sess_stats = stats[i].transpose(1, 0, 2)
        for s, sub_stats in enumerate(sess_stats):
            lab = None
            if not autocorrpar["byitem"]:
                lab = ["actual lag", "10x lag"][s]
            plot_util.plot_traces(
                sub_ax, xrans[i], sub_stats[0], sub_stats[1:], xticks=xticks,
                yticks=yticks, alpha=0.2, label=lab)

        plot_util.add_bars(sub_ax, hbars=seq_bars)
        sub_ax.set_ylim([0, 1])
        sub_ax.set_title(title, y=1.02)
        if sub_ax.is_last_row():
            sub_ax.set_xlabel("Lag (s)")

    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype],
            figpar["dirs"]["autocorr"])

    savename = (f"{datatype}_autocorr{byitemstr}_{sessstr}{dendstr}")

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar,
                                quantpar, sess_info, trace_stats, figpar=None,
                                savedir=None, modif=False):
    """
    plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar,
                                quantpar, sess_info, trace_stats)

    From dictionaries, plots traces by quantile, locked to transitions from
    unexpected to expected or v.v. with each session in a separate subplot.

    Returns figure name and save directory path.

    NOTE: analyspar["dend"] is ignored (treated as None) for these plots.

    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
                              (not modified)
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis
                              parameters
            ["analysis"] (str): analysis type (e.g., "l")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
                              session
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            if datatype == "roi":
                ["nrois"] (list)  : number of ROIs in session
        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the 2p frames for each
                                         session
            ["all_stats"] (list)       : list of 4D arrays or lists of trace
                                         data statistics across ROIs for each
                                         session, structured as:
                                            (unexp_len x) quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x (unexp_len x) quantiles
            ["lock"] (str)             : value to which segments are locked:
                                         "unexp", "exp" or "unexp_split"
            ["baseline"] (num)         : number of seconds used for baseline
            ["exp_stats"] (list)       : list of 3D arrays or lists of trace
                                         data statistics across ROIs for
                                         expected sampled sequences,
                                         structured as:
                                            quantiles (1) x stats (me, err)
                                            x frames
            ["exp_counts"] (array-like): number of sequences corresponding to
                                         exp_stats, structured as:
                                            sess x quantiles (1)

            if data is by unexp_len:
            ["unexp_lens"] (list)      : number of consecutive segments for
                                         each unexp_len, structured by session

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
        - modif (bool) : if True, modified (slimmed-down) plots are created
                         instead
                         default: False

    Returns:
        - fulldir (str) : final path of the directory in which the figure is
                          saved (may differ from input savedir, if datetime
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if any exp_stats entry has more than one quantile
    """

    # dendrite type is forced to None for these plots: work on a shallow copy
    # so that the caller's analyspar dictionary is not mutated
    analyspar = dict(analyspar)
    analyspar["dend"] = None

    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"],
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    basestr = sess_str_util.base_par_str(trace_stats["baseline"])
    basestr_pr = sess_str_util.base_par_str(trace_stats["baseline"], "print")

    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi"))

    n_sess = len(mouse_ns)

    xrans = [np.asarray(xran) for xran in trace_stats["xrans"]]
    all_stats = [np.asarray(sessst) for sessst in trace_stats["all_stats"]]
    exp_stats = [np.asarray(expst) for expst in trace_stats["exp_stats"]]
    all_counts = trace_stats["all_counts"]
    exp_counts = trace_stats["exp_counts"]

    lock = trace_stats["lock"]
    col_idx = 0
    if "unexp" in lock:
        lock = "unexp"
        col_idx = 1

    # plot unexp_lens default values
    # offset defaults to 0 so that it is always defined (previously it was
    # left undefined for gabors with gabfr "any"/"all", raising a NameError)
    offset = 0
    if stimpar["stimtype"] == "gabors":
        default_unexp_lens = [3.0, 4.5, 6.0]
        if stimpar["gabfr"] not in ["any", "all"]:
            offset = sess_str_util.gabfr_nbrs(stimpar["gabfr"])
    else:
        default_unexp_lens = [2.0, 3.0, 4.0]

    unexp_lab, len_ext = "", ""
    unexp_lens = [[None]] * n_sess
    unexp_len_default = True
    if "unexp_lens" in trace_stats.keys():
        unexp_len_default = False
        unexp_lens = trace_stats["unexp_lens"]
        len_ext = "_bylen"
        if stimpar["stimtype"] == "gabors":
            # convert to seconds, shifted by the gabor frame offset
            unexp_lens = [
                [sl * 1.5/5 - 0.3 * offset for sl in sls]
                for sls in unexp_lens]

    inv = 1 if lock == "unexp" else -1

    # RANGE TO PLOT
    if modif:
        st_val = -2.0
        end_val = 6.0
        n_ticks = int((end_val - st_val) // 2 + 1)
    else:
        n_ticks = 21

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    figpar = copy.deepcopy(figpar)
    if modif:
        figpar["init"]["subplot_wid"] = 6.5
    else:
        figpar["init"]["subplot_wid"] *= 2

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    exp_min, exp_max = np.inf, -np.inf
    for i, (stats, counts) in enumerate(zip(all_stats, all_counts)):
        sub_ax = plot_util.get_subax(ax, i)

        # plot expected data
        if exp_stats[i].shape[0] != 1:
            raise ValueError("Expected only one quantile for exp_stats.")

        n_lines = quantpar["n_quants"] * len(unexp_lens[i])
        cols = sess_plot_util.get_quant_cols(n_lines)[0][col_idx]
        if len(cols) < n_lines:
            cols = [None] * n_lines

        if modif:
            line = "2/3" if "23" in lines[i] else "5"
            plane = "somata" if "soma" in planes[i] else "dendrites"
            title = f"M{mouse_ns[i]} - layer {line} {plane}{dendstr_pr}"
            lab = "exp" if i == 0 else None
            y_ax = None if i == 0 else ""

            # restrict plotted frames to [st_val, end_val] (inclusive of the
            # closest frames just outside the range)
            st, end = 0, len(xrans[i])
            st_vals = list(filter(
                lambda j: xrans[i][j] <= st_val, range(len(xrans[i]))))
            end_vals = list(filter(
                lambda j: xrans[i][j] >= end_val, range(len(xrans[i]))))
            if len(st_vals) != 0:
                st = st_vals[-1]
            if len(end_vals) != 0:
                end = end_vals[0] + 1
            time_slice = slice(st, end)
        else:
            title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
                u"{} ".format(statstr_pr) + f"{lock} locked across {dimstr}"
                f"{basestr_pr}\n(sess {sess_ns[i]}, {lines[i]} {planes[i]}"
                f"{dendstr_pr}{nroi_strs[i]})")
            lab = f"exp (no lock) ({exp_counts[i][0]})"
            y_ax = None
            st = 0
            end = len(xrans[i])
            time_slice = slice(None) # use all

        # add length markers
        use_unexp_lens = unexp_lens[i]
        if unexp_len_default:
            use_unexp_lens = default_unexp_lens
        leng_col = sess_plot_util.get_quant_cols(1)[0][col_idx][0]
        for leng in use_unexp_lens:
            if leng is None:
                continue
            edge = leng * inv
            # clip shading to the plotted time range
            if edge < 0:
                edge = np.max([xrans[i][st], edge])
            elif edge > 0:
                edge = np.min([xrans[i][end - 1], edge])
            plot_util.add_vshade(
                sub_ax, 0, edge, color=leng_col, alpha=0.1)

        sess_plot_util.add_axislabels(
            sub_ax, fluor=analyspar["fluor"], datatype=datatype, y_ax=y_ax)
        plot_util.add_bars(sub_ax, hbars=0)
        alpha = np.min([0.4, 0.8 / n_lines])

        if stimpar["stimtype"] == "gabors":
            sess_plot_util.plot_gabfr_pattern(
                sub_ax, xrans[i], offset=offset, bars_omit=[0] + unexp_lens[i])

        plot_util.plot_traces(
            sub_ax, xrans[i][time_slice], exp_stats[i][0][0, time_slice],
            exp_stats[i][0][1:, time_slice], n_xticks=n_ticks, alpha=alpha,
            label=lab, alpha_line=0.8, color="darkgray", xticks="auto")

        # get expected data range to adjust y lims
        exp_min = np.min([exp_min, np.nanmin(exp_stats[i][0][0])])
        exp_max = np.max([exp_max, np.nanmax(exp_stats[i][0][0])])

        n = 0 # count lines plotted
        for s, unexp_len in enumerate(unexp_lens[i]):
            if unexp_len is not None:
                counts, stats = all_counts[i][s], all_stats[i][s]
                # remove offset
                unexp_lab = f"unexp len {unexp_len + 0.3 * offset}"
            else:
                unexp_lab = "unexp" if modif else f"{lock} lock"
            for q, qu_idx in enumerate(quantpar["qu_idx"]):
                qu_lab = ""
                if quantpar["n_quants"] > 1:
                    qu_lab = "{} ".format(sess_str_util.quantile_str(
                        qu_idx, quantpar["n_quants"], str_type="print"))
                lab = f"{qu_lab}{unexp_lab}"
                if modif:
                    lab = lab if i == 0 else None
                else:
                    lab = f"{lab} ({counts[q]})"
                if n == 2 and cols[n] is None:
                    sub_ax.plot([], []) # to advance the color cycle (past gray)
                plot_util.plot_traces(
                    sub_ax, xrans[i][time_slice], stats[q][0, time_slice],
                    stats[q][1:, time_slice], title, alpha=alpha, label=lab,
                    n_xticks=n_ticks, alpha_line=0.8, color=cols[n],
                    xticks="auto")
                n += 1
            if unexp_len is not None:
                # mark the unexpected length in the last plotted line's color
                plot_util.add_bars(
                    sub_ax, hbars=unexp_len,
                    color=sub_ax.lines[-1].get_color(), alpha=1)

    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype],
            figpar["dirs"]["unexp_qu"],
            f"{lock}_lock", basestr.replace("_", ""))

    if not modif:
        if stimpar["stimtype"] == "visflow":
            plot_util.rel_confine_ylims(sub_ax, [exp_min, exp_max], 5)

    qu_str = f"_{quantpar['n_quants']}q"
    if quantpar["n_quants"] == 1:
        qu_str = ""

    savename = (f"{datatype}_av_{lock}_lock{len_ext}{basestr}_{sessstr}"
        f"{dendstr}{qu_str}")
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar,
                                 quantpar, sess_info, trace_stats,
                                 figpar=None, savedir=None, modif=False):
    """
    plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar,
                                 quantpar, sess_info, trace_stats)

    From dictionaries, plots traces by quantile/unexpected with each session
    in a separate subplot.

    Returns figure name and save directory path.

    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis
                              parameters
            ["analysis"] (str): analysis type (e.g., "t")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
                              session
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            if extrapar["datatype"] == "roi":
                ["nrois"] (list)  : number of ROIs in session
        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the frames, for each
                                         session
            ["all_stats"] (list)       : list of 4D arrays or lists of trace
                                         data statistics across ROIs for each
                                         session, structured as:
                                            sess x unexp x quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x unexp x quantiles

    Optional args:
        - figpar (dict): dictionary containing the following figure parameter
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
        - modif (bool) : if True, modified (slimmed-down) plots are created
                         instead
                         default: False

    Returns:
        - fulldir (str) : final path of the directory in which the figure is
                          saved (may differ from input savedir, if datetime
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """

    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"],
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi"))

    n_sess = len(mouse_ns)

    xrans = [np.asarray(xran) for xran in trace_stats["xrans"]]
    all_stats = [np.asarray(sessst) for sessst in trace_stats["all_stats"]]
    all_counts = trace_stats["all_counts"]

    cols, lab_cols = sess_plot_util.get_quant_cols(quantpar["n_quants"])
    alpha = np.min([0.4, 0.8 / quantpar["n_quants"]])

    unexps = ["exp", "unexp"]
    n_xticks = 6

    # quantile labels depend only on the quantile: build them once
    qu_labs = [""] * len(quantpar["qu_idx"])
    if quantpar["n_quants"] > 1:
        qu_labs = [
            "{} ".format(sess_str_util.quantile_str(
                qu_idx, quantpar["n_quants"], str_type="print"))
            for qu_idx in quantpar["qu_idx"]]

    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])

    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)

        # title and y axis label depend only on the session, so they are set
        # once per subplot (this also guarantees y_ax is always defined)
        if modif:
            line = "2/3" if "23" in lines[i] else "5"
            plane = "somata" if "soma" in planes[i] else "dendrites"
            title = f"M{mouse_ns[i]} - layer {line} {plane}{dendstr_pr}"
            y_ax = None if i == 0 else ""
        else:
            title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
                u"{}\n".format(statstr_pr) + f"across {dimstr} (sess "
                f"{sess_ns[i]}, {lines[i]} {planes[i]}{dendstr_pr}"
                f"{nroi_strs[i]})")
            y_ax = None

        for s, [col, leg_ext] in enumerate(zip(cols, unexps)):
            for q, qu_lab in enumerate(qu_labs):
                if modif:
                    # only label the first subplot in slimmed-down plots
                    leg = f"{qu_lab}{leg_ext}" if i == 0 else None
                else:
                    leg = f"{qu_lab}{leg_ext} ({all_counts[i][s][q]})"
                plot_util.plot_traces(
                    sub_ax, xrans[i], all_stats[i][s, q, 0],
                    all_stats[i][s, q, 1:], title, color=col[q], alpha=alpha,
                    label=leg, n_xticks=n_xticks, xticks="auto")

        sess_plot_util.add_axislabels(
            sub_ax, fluor=analyspar["fluor"], datatype=datatype, y_ax=y_ax)

    plot_util.turn_off_extra(ax, n_sess)

    if stimpar["stimtype"] == "gabors":
        sess_plot_util.plot_labels(
            ax, stimpar["gabfr"], "both", pre=stimpar["pre"],
            post=stimpar["post"], cols=lab_cols,
            sharey=figpar["init"]["sharey"])

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype],
            figpar["dirs"]["unexp_qu"])

    qu_str = f"_{quantpar['n_quants']}q"
    if quantpar["n_quants"] == 1:
        qu_str = ""

    savename = f"{datatype}_av_{sessstr}{dendstr}{qu_str}"
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, sess_info,
                           corr_data, figpar=None, savedir=None):
    """
    plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, sess_info,
                           corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in
    pupil diameter and each ROI, for gabors versus visual flow responses for
    each session.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis
                                parameters
            ["analysis"] (str): analysis type (e.g., "r")
            ["datatype"] (str): datatype (e.g., "roi")
        - sess_info (dict)    : dictionary containing information from each
                                session
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["stim_order"] (list): ordered list of stimtypes
            ["roi_corrs"] (list) : nested list of correlations between pupil
                                   and ROI responses changes locked to
                                   unexpected, structured as
                                       session x stimtype x ROI
            ["corrs"] (list)     : list of correlation between stimtype
                                   correlations for each session

    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter
                          dictionaries
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None

    Returns:
        - fulldir (Path): final path of the directory in which the figure is
                          saved (may differ from input savedir, if datetime
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """

    stimstr_prs = []
    for stimtype in corr_data["stim_order"]:
        stimstr_pr = sess_str_util.stim_par_str(
            stimtype, stimpar["visflow_dir"], stimpar["visflow_size"],
            stimpar["gabk"], "print")
        # drop a trailing plural "s" (endswith() is safe on an empty string,
        # unlike indexing with [-1])
        if stimstr_pr.endswith("s"):
            stimstr_pr = stimstr_pr[:-1]
        stimstr_prs.append(stimstr_pr)

    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = f"sess{sesspar['sess_n']}_{sesspar['plane']}"
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    label_str = sess_str_util.fluor_par_str(
        analyspar["fluor"], str_type="print")
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")

    logger.info("Plotting pupil-ROI difference correlations for "
        "{} vs {}.".format(*corr_data["stim_order"]))

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="comma")

    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    # copy, so the caller's figpar is not modified below
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = True
    figpar["init"]["sharey"] = True

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])

    suptitle = (u"Relationship between pupil diam. and {} changes, locked to "
        "unexpected events\n{} for each ROI ({} vs {})".format(
            label_str, lab_app, *corr_data["stim_order"]))

    for i, sess_roi_corrs in enumerate(corr_data["roi_corrs"]):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} (sess {sess_ns[i]}, {lines[i]} "
            f"{planes[i]}{dendstr_pr}{nroi_strs[i]})")
        # scatter per-ROI correlations: first stimtype (x) vs second (y)
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        sub_ax.plot(
            sess_roi_corrs[0], sess_roi_corrs[1], marker=".", linestyle="None",
            label=corr)
        sub_ax.set_title(title, y=1.01)
        if sub_ax.is_last_row():
            sub_ax.set_xlabel(f"{stimstr_prs[0].capitalize()} correlations")
        if sub_ax.is_first_col():
            sub_ax.set_ylabel(f"{stimstr_prs[1].capitalize()} correlations")
        sub_ax.legend()

    plot_util.turn_off_extra(ax, n_sess)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        savedir = Path(
            figpar["dirs"]["roi"],
            figpar["dirs"]["pupil"])

    savename = f"roi_diff_corrbyroi_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def _single_df_value(vals):
    """
    _single_df_value(vals)

    Returns the single value held by a one-row column selection (e.g., a
    length-1 pandas Series), as a scalar. Assigning the length-1 object itself
    into a numpy array element relies on deprecated ndim > 0 to scalar
    conversion.

    Required args:
        - vals (array-like or scalar): values, expected to hold exactly one
                                       element

    Returns:
        - (scalar): first (and only expected) value
    """

    return np.asarray(vals).reshape(-1)[0]


def plot_data_summ(plot_lines, data, stats, shuff_stats, title, savename,
                   CI=0.95, q1v4=False, evu=False, comp="unexp", modif=False,
                   no_save=False):
    """
    plot_data_summ(plot_lines, data, stats, shuff_stats, title, savename)

    Plots summary data for a specific comparison, for each line and plane and
    saves figure.

    Required args:
        - plot_lines (pd DataFrame): DataFrame containing scores summary
                                     for specific comparison and criteria
        - data (str)               : label of type of data to plot,
                                     e.g., "epoch_n" or "test_acc_bal"
        - stats (list)             : list of stats to use for non shuffled
                                     data, e.g., ["mean", "sem", "sem"]
        - shuff_stats (list)       : list of stats to use for shuffled
                                     data, e.g., ["median", "p2p5", "p97p5"]
        - title (str)              : general plot titles (must contain "data")
        - savename (str)           : plot save path

    Optional args:
        - CI (num)      : CI for shuffled data (e.g., 0.95)
                          default: 0.95
        - q1v4 (bool)   : if True, analysis is separated across first and
                          last quartiles
                          default: False
        - evu (bool)    : if True, the first dataset will include expected
                          sequences and the second will include unexpected
                          sequences
                          default: False
        - comp (str)    : type of comparison
                          default: "unexp"
        - modif (bool)  : if True, plots are made in a modified (simplified
                          way)
                          default: False
        - no_save (bool): if True, figure is not saved
                          default: False

    Returns:
        - fig (plt Figure) : figure, or None if data is "test_acc_bal" and no
                             matching column is found in plot_lines (a
                             RuntimeWarning is raised in that case)

    Raises:
        - RuntimeError: if shuffled data across mice is missing, or if several
                        dataframe rows match the same criteria
    """

    # check for missing data before creating the figure, so that no figure is
    # left open when returning early
    if data == "test_acc_bal":
        if not any(data in key for key in plot_lines.keys()):
            warnings.warn("test_acc_bal was not recorded",
                category=RuntimeWarning, stacklevel=1)
            return None

    celltypes = [[x, y] for y in ["dend", "soma"] for x in ["L23", "L5"]]

    max_sess = max(plot_lines["sess_n"].tolist())

    fig, ax = init_res_fig(len(celltypes), max_sess, modif)

    if modif:
        fig.suptitle(title, y=1.0, weight="bold")

    n_vals = 5 # (mean/med, sem/2.5p, sem/97.5p, n_rois, n_runs)

    split_oris = sess_str_util.get_split_oris(comp)
    data_types = gen_util.list_if_not(data)

    if (q1v4 or evu or split_oris) and "test_acc" in data:
        ext_test = sess_str_util.ext_test_str(q1v4, evu, comp)
        n_vals = 8 # (extra mean/med, sem/2.5p, sem/97.5p for Q4)
        if data == "test_acc":
            data_types = ["test_acc", f"{ext_test}_acc"]
        elif data == "test_acc_bal":
            data_types = ["test_acc_bal", f"{ext_test}_acc_bal"]
        else:
            gen_util.accepted_values_error(
                "data", data, ["test_acc", "test_acc_bal"])

    for i, [line, plane] in enumerate(celltypes):
        sub_ax = plot_util.get_subax(ax, i)
        if not modif:
            sub_ax.set_xlim(-0.5, max_sess - 0.5)

        # get the right rows in dataframe
        cols = ["plane"]
        cri = [plane]
        curr_lines = gen_util.get_df_vals(plot_lines.loc[
            plot_lines["line"].str.contains(line)], cols, cri)
        cri_str = ", ".join(f"{col}: {crit}" for col, crit in zip(cols, cri))
        if len(curr_lines) == 0: # no data
            warnings.warn(f"No data found for {line} {plane}, {cri_str}",
                category=RuntimeWarning, stacklevel=1)
            continue

        # shuffle or non shuffle missing
        skip = False
        for shuff in [False, True]:
            if shuff not in curr_lines["shuffle"].tolist():
                warnings.warn(f"No shuffle={shuff} data found for {line} "
                    f"{plane}, {cri_str}", category=RuntimeWarning,
                    stacklevel=1)
                skip = True
        if skip:
            continue

        sess_ns = gen_util.get_df_vals(curr_lines, label="sess_n", dtype=int)
        mouse_ns = gen_util.get_df_vals(curr_lines, label="mouse_n",
            dtype=int)
        # mouse x sess x n_vals
        if -1 not in mouse_ns:
            raise RuntimeError("Shuffle data across mice is missing.")
        mouse_ns = gen_util.remove_if(mouse_ns, -1)
        data_arr = np.full((len(mouse_ns), int(max_sess), n_vals), np.nan)
        shuff_arr = np.full((1, int(max_sess), n_vals - 1), np.nan)

        for sess_n in sess_ns:
            sess_idx = int(sess_n - 1)
            sess_mice = gen_util.get_df_vals(
                curr_lines, "sess_n", sess_n, "mouse_n", dtype=int)
            # mouse -1 holds the shuffled data across mice
            for m, mouse_n in enumerate(mouse_ns + [-1]):
                if mouse_n not in sess_mice:
                    continue
                if mouse_n == -1:
                    stat_types = shuff_stats
                    arr = shuff_arr
                    m = 0
                else:
                    stat_types = stats
                    arr = data_arr

                curr_line = gen_util.get_df_vals(curr_lines,
                    ["sess_n", "mouse_n", "shuffle"],
                    [sess_n, mouse_n, mouse_n==-1])
                if len(curr_line) > 1:
                    raise RuntimeError("Several lines correspond to criteria.")
                elif len(curr_line) == 0:
                    continue

                for st, stat in enumerate(stat_types):
                    for d, dat in enumerate(data_types):
                        # separate name, so as not to shadow the subplot index
                        val_idx = d * 3 + st
                        arr[m, sess_idx, val_idx] = _single_df_value(
                            curr_line[f"{dat}_{stat}"])
                if mouse_n != -1:
                    arr[m, sess_idx, -2] = _single_df_value(
                        curr_line["n_rois"])
                    arr[m, sess_idx, -1] = _single_df_value(
                        curr_line["runs_total"] - curr_line["runs_nan"])

        summ_subplot(sub_ax, data_arr, shuff_arr, title, mouse_ns, sess_ns,
            line, plane, stats[0], stats[1], CI, q1v4, evu, split_oris, modif)

    if modif:
        n_sess_keep = 3
        ylab = ax[0, 0].get_ylabel()
        ax[0, 0].set_ylabel("")
        sess_plot_util.format_linpla_subaxes(ax, ylab=ylab,
            xticks=np.arange(1, n_sess_keep + 1))

        yticks = ax[0, 0].get_yticks()
        # always set ticks (even again) before setting labels
        ax[1, 0].set_yticks(yticks)
        ax[1, 0].set_yticklabels([int(v) for v in yticks],
            fontdict={"weight": "bold"})

    if not no_save:
        fig.savefig(savename)

    return fig
def run_sess_lstm(sessid, args):
    """
    run_sess_lstm(sessid, args)

    Prepares stimulus and ROI data for one session and, as written below,
    trains an LSTM (PredLSTM or ConvPredROILSTM) to predict ROI traces from
    stimulus data. It saves the lowest-validation-loss model, a loss CSV, a
    loss plot and a target vs predicted traces plot.

    NOTE: the function currently always raises NotImplementedError after
    preparing the training data, so the training and plotting code below is
    not reached.

    Required args:
        - sessid (int?)          : session identifier passed to
                                   sess_gen_util.init_sessions (presumably a
                                   session ID - verify against caller)
        - args (Argument parser) : parser with the attributes used below,
                                   e.g., parallel, plt_bkend, seed, device,
                                   lr_ex, conv, out_ch, datadir, mouse_df,
                                   runtype, stimtype, output, hidden_dim,
                                   num_layers, batchsize, dropout, n_epochs

    Raises:
        - NotImplementedError: always (see NOTE above)
    """

    if args.parallel and args.plt_bkend is not None:
        plt.switch_backend(args.plt_bkend) # needs to be repeated within joblib

    args.seed = rand_util.seed_all(args.seed, args.device, seed_torch=True)

    train_p = 0.8
    lr = 1. * 10**(-args.lr_ex)
    if args.conv:
        conv_str = "_conv"
        outch_str = f"_{args.out_ch}outch"
    else:
        conv_str = ""
        outch_str = ""

    # Input output parameters
    n_stim_s = 0.6
    n_roi_s = 0.3

    # Stim/traces for training
    train_gabfr = 0
    train_post = 0.9 # up to C
    roi_train_pre = 0 # from A
    stim_train_pre = 0.3 # from preceeding grayscreen

    # Stim/traces for testing (separated for unexp vs exp)
    test_gabfr = 3
    test_post = 0.6 # up to grayscreen
    roi_test_pre = 0 # from D/U
    stim_test_pre = 0.3 # from preceeding C

    sess = sess_gen_util.init_sessions(
        sessid, args.datadir, args.mouse_df, args.runtype, full_table=False,
        fluor="dff", dend="extr", run=True, temp_log="warning")[0]

    analysdir = sess_gen_util.get_analysdir(
        sess.mouse_n, sess.sess_n, sess.plane, stimtype=args.stimtype,
        comp=None)
    dirname = Path(args.output, analysdir)
    file_util.createdir(dirname, log_dir=False)

    # Must not scale ROIs or running BEFOREHAND. Must do after to use only
    # network available data.

    # seq x frame x gabor x par
    logger.info("Preparing stimulus parameter dataframe",
        extra={"spacing": "\n"})
    train_stim_wins, run_stats = sess_data_util.get_stim_data(
        sess, args.stimtype, n_stim_s, train_gabfr, stim_train_pre,
        train_post, gabk=16, run=True)

    logger.info("Adding ROI data")
    xran, train_roi_wins, roi_stats = sess_data_util.get_roi_data(
        sess, args.stimtype, n_roi_s, train_gabfr, roi_train_pre, train_post,
        gabk=16)

    logger.warning("Preparing windowed datasets (too slow - to be improved)")
    raise NotImplementedError("Not implemented properly - some error leads "
        "to excessive memory requests.")

    # NOTE: everything below is currently unreachable (see raise above)
    test_stim_wins = []
    test_roi_wins = []
    for unexp in [0, 1]:
        # test data is normalized with the training data statistics
        stim_wins = sess_data_util.get_stim_data(
            sess, args.stimtype, n_stim_s, test_gabfr, stim_test_pre,
            test_post, unexp, gabk=16, run_mean=run_stats[0],
            run_std=run_stats[1])
        test_stim_wins.append(stim_wins)

        roi_wins = sess_data_util.get_roi_data(
            sess, args.stimtype, n_roi_s, test_gabfr, roi_test_pre, test_post,
            unexp, gabk=16, roi_means=roi_stats[0],
            roi_stds=roi_stats[1])[1]
        test_roi_wins.append(roi_wins)

    n_pars = train_stim_wins.shape[-1] # n parameters (121)
    n_rois = train_roi_wins.shape[-1] # n ROIs

    hyperstr = (f"{args.hidden_dim}hd_{args.num_layers}hl_{args.lr_ex}lrex_"
        f"{args.batchsize}bs{outch_str}{conv_str}")

    dls = data_util.create_dls(
        train_stim_wins, train_roi_wins, train_p=train_p, test_p=0,
        batchsize=args.batchsize, thresh_cl=0, train_shuff=True)[0]
    train_dl, val_dl, _ = dls

    test_dls = []
    for s in [0, 1]:
        dl = data_util.init_dl(
            test_stim_wins[s], test_roi_wins[s], batchsize=args.batchsize)
        test_dls.append(dl)

    logger.info("Running LSTM")
    if args.conv:
        lstm = ConvPredROILSTM(args.hidden_dim, n_rois, out_ch=args.out_ch,
            num_layers=args.num_layers, dropout=args.dropout)
    else:
        lstm = PredLSTM(n_pars, args.hidden_dim, n_rois,
            num_layers=args.num_layers, dropout=args.dropout)

    lstm = lstm.to(args.device)
    # reduction="sum" is the non-deprecated equivalent of size_average=False
    lstm.loss_fn = torch.nn.MSELoss(reduction="sum")
    lstm.opt = torch.optim.Adam(lstm.parameters(), lr=lr)

    loss_df = pd.DataFrame(
        np.nan, index=range(args.n_epochs), columns=["train", "val"])
    min_val = np.inf
    for ep in range(args.n_epochs):
        logger.info(f"====> Epoch {ep}", extra={"spacing": "\n"})
        # epoch 0 is run without training (presumably to record the
        # untrained baseline loss)
        if ep == 0:
            train_loss = run_dl(lstm, train_dl, args.device, train=False)
        else:
            train_loss = run_dl(lstm, train_dl, args.device, train=True)
        val_loss = run_dl(lstm, val_dl, args.device, train=False)
        # .loc[row, col] writes into loss_df itself, whereas chained indexing
        # (loss_df["train"].loc[ep] = ...) may write into a temporary copy
        loss_df.loc[ep, "train"] = train_loss/train_dl.dataset.n_samples
        loss_df.loc[ep, "val"] = val_loss/val_dl.dataset.n_samples
        logger.info(f"Training loss : {loss_df['train'].loc[ep]}")
        logger.info(f"Validation loss: {loss_df['val'].loc[ep]}")

        # save the model whenever validation loss reaches a new low (always
        # at epoch 0), replacing the previously saved model and loss CSV
        if ep == 0 or val_loss < min_val:
            prev_model = glob.glob(str(Path(dirname, f"{hyperstr}_ep*.pth")))
            prev_df = glob.glob(str(Path(dirname, f"{hyperstr}.csv")))
            min_val = val_loss
            saved_ep = ep

            if len(prev_model) == 1 and len(prev_df) == 1:
                Path(prev_model[0]).unlink()
                Path(prev_df[0]).unlink()

            savename = f"{hyperstr}_ep{ep}"
            savefile = Path(dirname, savename)

            torch.save(
                {"net": lstm.state_dict(), "opt": lstm.opt.state_dict()},
                f"{savefile}.pth")

            file_util.saveinfo(loss_df, hyperstr, dirname, "csv")

    plot_util.linclab_plt_defaults(font=["Arial", "Liberation Sans"],
        fontdir=DEFAULT_FONTDIR)
    fig, ax = plt.subplots(1)
    for dataset in ["train", "val"]:
        plot_util.plot_traces(ax, range(args.n_epochs),
            np.asarray(loss_df[dataset]), label=dataset,
            title=f"Average loss (MSE) ({n_rois} ROIs)", xticks="auto")
    fig.savefig(Path(dirname, f"{hyperstr}_loss"))
    plt.close(fig) # release the loss figure before creating the next one

    # reload the best (lowest validation loss) model
    savemod = Path(dirname, f"{hyperstr}_ep{saved_ep}.pth")
    checkpoint = torch.load(savemod)
    lstm.load_state_dict(checkpoint["net"])

    n_samples = 20
    val_idx = np.random.choice(range(val_dl.dataset.n_samples), n_samples)
    val_samples = val_dl.dataset[val_idx]
    xrans = data_util.get_win_xrans(
        xran, val_samples[1].shape[1], val_idx.tolist())

    fig, ax = plot_util.init_fig(n_samples, ncols=4, sharex=True,
        subplot_hei=2, subplot_wid=5)

    lstm.eval()
    with torch.no_grad():
        batch_len, seq_len, n_items = val_samples[1].shape
        # the network takes sequence-first input, hence the transposes
        pred_tr = lstm(val_samples[0].transpose(1, 0).to(args.device))
        pred_tr = pred_tr.view([seq_len, batch_len, n_items]).transpose(1, 0)

    for lab, data in zip(["target", "pred"], [val_samples[1], pred_tr]):
        # move to CPU first: .numpy() fails on tensors held on a GPU device
        data = data.detach().cpu().numpy()
        for n in range(n_samples):
            roi_n = np.random.choice(range(data.shape[-1]))
            sub_ax = plot_util.get_subax(ax, n)
            plot_util.plot_traces(
                sub_ax, xrans[n], data[n, :, roi_n], label=lab, xticks="auto")
            plot_util.set_ticks(sub_ax, "x", xran[0], xran[-1], n=7)

    sess_plot_util.plot_labels(ax, train_gabfr, plot_vals="exp",
        pre=roi_train_pre, post=train_post)

    fig.suptitle(f"Target vs predicted validation traces ({n_rois} ROIs)")
    fig.savefig(Path(dirname, f"{hyperstr}_traces"))