def plot_from_dict(direc, plt_bkend=None, fontdir=None):
    """
    plot_from_dict(direc)

    Plots logistic regression data from the hyperparameters dictionary
    saved in the specified results directory.

    Required args:
        - direc (Path): path to directory in which dictionaries to plot data 
                        from are located

    Optional_args:
        - plt_bkend (str): mpl backend to use for plotting (e.g., "agg")
                           default: None
        - fontdir (Path) : directory in which additional fonts are stored
                           default: None
    """

    logger.info(f"Plotting from hyperparameters in: {direc}", 
        extra={"spacing": "\n"})

    # set the plotting backend and fonts before any figure is created
    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)

    results_dir = Path(direc)
    hyperpars = file_util.loadfile(
        "hyperparameters.json", fulldir=results_dir)

    # only plot if logistic regression parameters were recorded
    if "logregpar" in hyperpars:
        plot_traces_scores(hyperpars, savedir=results_dir)

    plot_util.cond_close_figs()
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False, 
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Plots pupil analysis data from a dictionary containing analysis 
    parameters and results.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional_args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are 
                            stored
                            default: None
        - parallel (bool) : if True, some of the analysis is parallelized 
                            across CPU cores
                            default: False
        - datetime (bool) : figpar["save"] datatime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(f"Plotting from dictionary: {dict_path}", 
        extra={"spacing": "\n"})

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend, fontdir=fontdir, datetime=datetime, 
        overwrite=overwrite
        )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent
    analysis = info["extrapar"]["analysis"]

    # analysis codes mapped to plotting functions:
    # "c": correlation between pupil and roi/run changes for each session
    # "r": difference correlation per ROI between stimuli
    plot_fcts = {
        "c": plot_pup_diff_corr,
        "r": plot_pup_roi_stim_corr,
        }

    if analysis in plot_fcts:
        plot_fcts[analysis](figpar=figpar, savedir=savedir, **info)
    else:
        warnings.warn(f"No plotting function for analysis {analysis}", 
            category=UserWarning, stacklevel=1)

    plot_util.cond_close_figs()
def run_analyses(sessions, analysis_dict, analyses, datatype="roi", 
                 seed=None, parallel=False):
    """
    run_analyses(sessions, analysis_dict, analyses)

    Runs each requested analysis on the sessions, using the parameters 
    passed.

    Required args:
        - sessions (list)     : list of sessions, possibly nested
        - analysis_dict (dict): analysis parameter dictionary 
                                (see init_param_cont())
        - analyses (str)      : analyses to run

    Optional args:
        - datatype (str) : datatype ("run", "roi")
                           default: "roi"
        - seed (int)     : seed to use
                           default: None
        - parallel (bool): if True, some analyses are parallelized across 
                           CPU cores
                           default: False
    """

    if not sessions:
        logger.warning("No sessions meet these criteria.")
        return

    # changes backend and defaults
    plot_util.manage_mpl(cmap=False, **analysis_dict["figpar"]["mng"])
    sess_plot_util.update_plt_linpla()

    fct_dict = get_analysis_fcts()

    # assemble keyword arguments shared across analyses
    args_dict = copy.deepcopy(analysis_dict)
    args_dict["seed"] = seed
    args_dict["parallel"] = parallel
    args_dict["datatype"] = datatype

    # run through analyses
    for analysis in analyses:
        if analysis not in fct_dict:
            raise ValueError(f"{analysis} analysis not found.")
        fct, datatype_req = fct_dict[analysis]
        # skip analyses that do not apply to this datatype
        if datatype not in datatype_req:
            continue
        # pass only the arguments the analysis function accepts
        args_dict_use = gen_util.keep_dict_keys(
            args_dict, inspect.getfullargspec(fct).args)
        fct(sessions=sessions, analysis=analysis, **args_dict_use)

    plot_util.cond_close_figs()
def run_analyses(sessions, analysis_dict, analyses, seed=None, 
                 parallel=False, plot_tc=True):
    """
    run_analyses(sessions, analysis_dict, analyses)

    Runs requested analyses on sessions using the parameters passed.

    Required args:
        - sessions (list)     : list of sessions, possibly nested
        - analysis_dict (dict): analysis parameter dictionary 
                                (see init_param_cont())
        - analyses (str)      : analyses to run

    Optional args:
        - seed (int)     : seed to use
                           default: None
        - parallel (bool): if True, some analyses are parallelized across 
                           CPU cores
                           default: False
        - plot_tc (bool) : if True, tuning curves are plotted for each ROI
                           default: True
    """

    if len(sessions) == 0:
        logger.warning("No sessions meet these criteria.")
        return

    # nested session lists indicate comparison analyses
    # (fix: direct isinstance() instead of `True if ... else False`)
    comp = isinstance(sessions[0], list)

    # changes backend and defaults
    plot_util.manage_mpl(cmap=False, **analysis_dict["figpar"]["mng"])

    fct_dict = get_analysis_fcts()

    args_dict = copy.deepcopy(analysis_dict)
    for key, item in zip(["seed", "parallel", "plot_tc"], 
                         [seed, parallel, plot_tc]):
        args_dict[key] = item

    # run through analyses
    for analysis in analyses:
        if analysis not in fct_dict.keys():
            raise ValueError(f"{analysis} analysis not found.")
        fct, comp_req = fct_dict[analysis]
        # skip analyses whose comparison requirement does not match the 
        # structure of the sessions passed
        if comp_req != comp:
            continue
        # pass only the arguments the analysis function accepts
        args_dict_use = gen_util.keep_dict_keys(
            args_dict, inspect.getfullargspec(fct).args)
        fct(sessions=sessions, analysis=analysis, **args_dict_use)

    plot_util.cond_close_figs()
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, plot_tc=True, 
                   parallel=False, datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Plots data (modified versions of the plots) from dictionaries containing 
    analysis parameters and results.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional_args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are 
                            stored
                            default: None
        - plot_tc (bool)  : if True, tuning curves are plotted for each ROI 
                            (dummy argument)
                            default: True
        - parallel (bool) : if True, some of the analysis is parallelized 
                            across CPU cores
                            default: False
        - datetime (bool) : figpar["save"] datatime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(f"Plotting from dictionary: {dict_path}", 
        extra={"spacing": "\n"})

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend, fontdir=fontdir, datetime=datetime, 
        overwrite=overwrite)
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    # enlarge titles for these modified plots
    plt.rcParams["figure.titlesize"] = "xx-large"
    plt.rcParams["axes.titlesize"] = "xx-large"

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent
    analysis = info["extrapar"]["analysis"]

    # analysis codes mapped to plotting functions:
    # "t": average traces by quantile x unexpected for each session
    # "l": average traces by quantile, locked to unexpected, for each session
    modif_fcts = {
        "t": gen_plots.plot_traces_by_qu_unexp_sess,
        "l": gen_plots.plot_traces_by_qu_lock_sess,
        }

    if analysis in modif_fcts:
        modif_fcts[analysis](
            figpar=figpar, savedir=savedir, modif=True, **info)
    else:
        warnings.warn(f"No modified plotting option for analysis {analysis}", 
            category=UserWarning, stacklevel=1)

    plot_util.cond_close_figs()
def plot_from_dicts(direc, source="roi", plt_bkend=None, fontdir=None, 
                    plot_tc=True, parallel=False, datetime=True, 
                    overwrite=False, pattern="", depth=0):
    """
    plot_from_dicts(direc)

    Plots data from dictionaries containing analysis parameters and results, 
    or path to results.

    Required args:
        - direc (Path): path to directory in which dictionaries to plot data 
                        from are located or path to a single json file

    Optional_args:
        - source (str)    : plotting source ("roi", "run", "gen", "pup", 
                            "modif", "logreg", "glm")
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : directory in which additional fonts are stored
                            default: None
        - plot_tc (bool)  : if True, tuning curves are plotted for each ROI
                            default: True
        - parallel (bool) : if True, some of the analysis is parallelized 
                            across CPU cores
                            default: False
        - datetime (bool) : figpar["save"] datatime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
        - pattern (str)   : pattern based on which to include json files in 
                            direc if direc is a directory
                            default: ""
        - depth (int)     : maximum depth at which to check for json files if 
                            direc is a directory
                            default: 0
    """

    file_util.checkexists(direc)

    direc = Path(direc)

    if direc.is_dir():
        # logreg results are identified by their hyperparameters file; all 
        # other sources plot from any json found
        if source == "logreg":
            targ_file = "hyperparameters.json"
        else:
            targ_file = "*.json"

        # collect matching jsons at every directory depth from 0 to `depth`
        dict_paths = []
        for d in range(depth + 1):
            dict_paths.extend(
                glob.glob(str(Path(direc, *(["*"] * d), targ_file))))

        # keep only paths containing the requested substring pattern
        dict_paths = list(filter(lambda x: pattern in str(x), dict_paths))

        # for logreg, the plotting function expects the directory containing 
        # the hyperparameters file, not the file itself
        if source == "logreg":
            dict_paths = [Path(dp).parent for dp in dict_paths]

        if len(dict_paths) == 0:
            raise OSError(f"No jsons found in {direc} at "
                f"depth {depth} with pattern '{pattern}'.")

    elif ".json" not in str(direc):
        raise ValueError("If providing a file, must be a json file.")
    else:
        if (source == "logreg" and 
            not str(direc).endswith("hyperparameters.json")):
            raise ValueError("For logreg source, must provide path to "
                "a hyperparameters json file.")
        dict_paths = [direc]

    dict_paths = sorted(dict_paths)
    if len(dict_paths) > 1:
        logger.info(f"Plotting from {len(dict_paths)} dictionaries.")

    fontdir = Path(fontdir) if fontdir is not None else fontdir

    # keyword arguments shared by all plotting functions; pared down per 
    # function below
    args_dict = {
        "plt_bkend": plt_bkend, 
        "fontdir"  : fontdir,
        "plot_tc"  : plot_tc,
        "datetime" : datetime,
        "overwrite": overwrite,
        }

    # whether the plotting function itself takes a `parallel` argument
    pass_parallel = True
    sources = [
        "roi", "run", "gen", "modif", "pup", "logreg", "glm", "acr_sess"
        ]
    if source == "roi":
        fct = roi_plots.plot_from_dict
    elif source in ["run", "gen"]:
        fct = gen_plots.plot_from_dict
    elif source in ["pup", "pupil"]:
        fct = pup_plots.plot_from_dict
    elif source == "modif":
        fct = mod_plots.plot_from_dict
    elif source == "logreg":
        pass_parallel = False
        fct = logreg_plots.plot_from_dict
    elif source == "glm":
        pass_parallel = False
        fct = glm_plots.plot_from_dict
    elif source == "acr_sess":
        fct = acr_sess_plots.plot_from_dict
    else:
        gen_util.accepted_values_error("source", source, sources)

    # pass only the arguments the selected plotting function accepts
    args_dict = gen_util.keep_dict_keys(
        args_dict, inspect.getfullargspec(fct).args)
    # plot each dictionary, optionally in parallel across CPU cores
    gen_util.parallel_wrap(
        fct, dict_paths, args_dict=args_dict, parallel=parallel, 
        pass_parallel=pass_parallel)

    plot_util.cond_close_figs()
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False, 
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Plots data from dictionaries containing analysis parameters and results.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional_args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are 
                            stored
                            default: None
        - parallel (bool) : if True, some of the plotting is parallelized 
                            across CPU cores
                            default: False
        - datetime (bool) : figpar["save"] datatime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(f"Plotting from dictionary: {dict_path}", 
        extra={"spacing": "\n"})

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend, fontdir=fontdir, datetime=datetime, 
        overwrite=overwrite
        )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent
    analysis = info["extrapar"]["analysis"]

    # analysis codes mapped to plotting functions:
    # "f": 0. full traces for each session
    # "t": 1. average traces by quantile x unexpected for each session
    # "l": 2. average traces by quantile, locked to unexpected, per session
    # "m": 3. magnitude of change in dF/F area from first to last quantile 
    #         of unexpected vs no unexpected sequences, for each session
    # "a": 4. autocorrelations
    plot_fcts = {
        "f": plot_full_traces,
        "t": plot_traces_by_qu_unexp_sess,
        "l": plot_traces_by_qu_lock_sess,
        "m": plot_mag_change,
        "a": plot_autocorr,
        }

    if analysis in plot_fcts:
        plot_fcts[analysis](figpar=figpar, savedir=savedir, **info)
    else:
        warnings.warn(f"No plotting function for analysis {analysis}", 
            category=UserWarning, stacklevel=1)

    plot_util.cond_close_figs()
def run_regr(args):
    """
    run_regr(args)

    Does runs of a logistic regressions on the specified comparison and 
    range of sessions.

    Required args:
        - args (Argument parser): parser with analysis parameters as 
                                  attributes:
            alg (str)                 : algorithm to use 
                                        ("sklearn" or "pytorch")
            bal (bool)                : if True, classes are balanced
            batchsize (int)           : nbr of samples dataloader will load 
                                        per batch (for "pytorch" alg)
            visflow_dir (str)         : visual flow direction to analyse
            visflow_pre (float)       : number of seconds to include before 
                                        visual flow segments
            visflow_size (int or list): visual flow square sizes to include
            comp (str)                : type of comparison
            datadir (str)             : data directory
            dend (str)                : type of dendrites to use 
                                        ("allen" or "dend")
            device (str)              : device name (i.e., "cuda" or "cpu")
            ep_freq (int)             : frequency at which to log loss to 
                                        console
            error (str)               : error to take, i.e., "std" (for std 
                                        or quantiles) or "sem" (for SEM or 
                                        MAD)
            fluor (str)               : fluorescence trace type
            fontdir (str)             : directory in which additional fonts 
                                        are located
            gabfr (int)               : gabor frame of reference if 
                                        comparison is "unexp"
            gabk (int or list)        : gabor kappas to include
            gab_ori (list or str)     : gabor orientations to include
            incl (str or list)        : sessions to include 
                                        ("yes", "no", "all")
            lr (num)                  : model learning rate 
                                        (for "pytorch" alg)
            mouse_n (int)             : mouse number
            n_epochs (int)            : number of epochs
            n_reg (int)               : number of regular runs
            n_shuff (int)             : number of shuffled runs
            scale (bool)              : if True, each ROI is scaled
            output (str)              : general directory in which to save 
                                        output
            parallel (bool)           : if True, runs are done in parallel
            plt_bkend (str)           : pyplot backend to use
            q1v4 (bool)               : if True, analysis is trained on first 
                                        and tested on last quartiles
            exp_v_unexp (bool)        : if True, analysis is trained on 
                                        expected and tested on unexpected 
                                        sequences
            runtype (str)             : type of run ("prod" or "pilot")
            seed (int)                : seed to seed random processes with
            sess_n (int)              : session number
            stats (str)               : stats to take, i.e., "mean" or 
                                        "median"
            stimtype (str)            : stim to analyse 
                                        ("gabors" or "visflow")
            train_p (list)            : proportion of dataset to allocate to 
                                        training
            uniqueid (str or int)     : unique ID for analysis
            wd (float)                : weight decay value 
                                        (for "pytorch" arg)
    """

    # copy so attribute rebinding below does not mutate the caller's parser
    args = copy.deepcopy(args)

    if args.datadir is None:
        args.datadir = DEFAULT_DATADIR
    else:
        args.datadir = Path(args.datadir)

    # resolve the uniqueid: "datetime" becomes a timestamp string, and the 
    # string forms of None become actual None
    if args.uniqueid == "datetime":
        args.uniqueid = gen_util.create_time_str()
    elif args.uniqueid in ["None", "none"]:
        args.uniqueid = None

    # reseed random processes downstream only if no seed was provided
    reseed = False
    if args.seed in [None, "None"]:
        reseed = True

    # deal with parameters
    extrapar = {"uniqueid": args.uniqueid, "seed": args.seed}

    # technical (non-analysis) parameters passed through to the regression
    techpar = {
        "reseed": reseed,
        "device": args.device,
        "alg": args.alg,
        "parallel": args.parallel,
        "plt_bkend": args.plt_bkend,
        "fontdir": args.fontdir,
        "output": args.output,
        "ep_freq": args.ep_freq,
        "n_reg": args.n_reg,
        "n_shuff": args.n_shuff,
        }

    mouse_df = DEFAULT_MOUSE_DF_PATH

    stimpar = logreg.get_stimpar(
        args.comp, args.stimtype, args.visflow_dir, args.visflow_size, 
        args.gabfr, args.gabk, gab_ori=args.gab_ori, 
        visflow_pre=args.visflow_pre)

    # note: CLI flag is no_scale, so scaling is the inverted flag
    analyspar = sess_ntuple_util.init_analyspar(
        args.fluor, stats=args.stats, error=args.error, 
        scale=not (args.no_scale), dend=args.dend)

    # q1v4: train on first quantile (0), test on last (-1), of 4 quantiles
    if args.q1v4:
        quantpar = sess_ntuple_util.init_quantpar(4, [0, -1])
    else:
        quantpar = sess_ntuple_util.init_quantpar(1, 0)

    logregpar = sess_ntuple_util.init_logregpar(
        args.comp, not (args.not_ctrl), args.q1v4, args.exp_v_unexp, 
        args.n_epochs, args.batchsize, args.lr, args.train_p, args.wd, 
        args.bal, args.alg)

    # sessions/mice to exclude for this stimulus and run type
    omit_sess, omit_mice = sess_gen_util.all_omit(
        stimpar.stimtype, args.runtype, stimpar.visflow_dir, 
        stimpar.visflow_size, stimpar.gabk)

    sessids = sorted(
        sess_gen_util.get_sess_vals(
            mouse_df, "sessid", args.mouse_n, args.sess_n, args.runtype, 
            incl=args.incl, omit_sess=omit_sess, omit_mice=omit_mice))

    if len(sessids) == 0:
        logger.warning(
            f"No sessions found (mouse: {args.mouse_n}, sess: {args.sess_n}, "
            f"runtype: {args.runtype})")

    # run the regression on each session, one at a time
    for sessid in sessids:
        sess = sess_gen_util.init_sessions(
            sessid, args.datadir, mouse_df, args.runtype, full_table=False, 
            fluor=analyspar.fluor, dend=analyspar.dend, 
            temp_log="warning")[0]
        logreg.run_regr(
            sess, analyspar, stimpar, logregpar, quantpar, extrapar, techpar)

    plot_util.cond_close_figs()
def plot_summ(output, savename, stimtype="gabors", comp="unexp", ctrl=False, 
              visflow_dir="both", fluor="dff", scale=True, CI=0.95, 
              plt_bkend=None, fontdir=None, modif=False):
    """
    plot_summ(output)

    Plots summary data for a specific comparison, for each datatype in a 
    separate figure and saves figures.

    Required args:
        - output (str) : general directory in which summary dataframe is 
                         saved (runtype and q1v4 values are inferred from 
                         the directory name)
        - savename (str): name of the dataframe containing summary data to 
                          plot

    Optional args:
        - stimtype (str)   : stimulus type
                             default: "gabors"
        - comp (str)       : type of comparison
                             default: "unexp"
        - ctrl (bool)      : if True, control comparisons are analysed
                             default: False
        - visflow_dir (str): visual flow direction
                             default: "both"
        - fluor (str)      : fluorescence trace type
                             default: "dff"
        - scale (bool)     : whether ROIs are scaled
                             default: True
        - CI (num)         : CI for shuffled data
                             default: 0.95
        - plt_bkend (str)  : mpl backend to use for plotting (e.g., "agg")
                             default: None
        - fontdir (str)    : directory in which additional fonts are located
                             default: None
        - modif (bool)     : if True, plots are made in a modified 
                             (simplified way)
                             default: False
    """

    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)

    summ_scores_file = Path(output, savename)

    # bail out (with a warning, not an error) if the summary file is missing 
    # or empty
    if summ_scores_file.is_file():
        summ_scores = file_util.loadfile(summ_scores_file)
    else:
        warnings.warn(f"{summ_scores_file} not found.", 
            category=RuntimeWarning, stacklevel=1)
        return

    if len(summ_scores) == 0:
        warnings.warn(f"No data in {summ_scores_file}.", 
            category=RuntimeWarning, stacklevel=1)
        return

    # drop NaN lines
    summ_scores = summ_scores.loc[~summ_scores["epoch_n_mean"].isna()]

    # datatypes to plot, one figure each
    data_types = ["epoch_n", "test_acc", "test_acc_bal"]
    data_titles = ["Epoch nbrs", "Test accuracy", "Test accuracy (balanced)"]

    # statistics for real data vs shuffled data (median + CI percentiles)
    stats = ["mean", "sem", "sem"]
    shuff_stats = ["median"] + math_util.get_percentiles(CI)[1]

    # infer analysis type and runtype from the output directory name
    q1v4, evu = False, False
    if "_q1v4" in str(output):
        q1v4 = True
    elif "_evu" in str(output):
        evu = True

    runtype = "prod"
    if "pilot" in str(output):
        runtype = "pilot"

    # visual flow direction only applies to non-gabor stimuli
    if stimtype == "gabors":
        visflow_dir = "none"
        stim_str = "gab"
        stim_str_pr = "gabors"
    else:
        visflow_dir = sess_gen_util.get_params(stimtype, visflow_dir)[0]
        if isinstance(visflow_dir, list) and len(visflow_dir) == 2:
            visflow_dir = "both"
        stim_str = sess_str_util.dir_par_str(visflow_dir, str_type="file")
        stim_str_pr = sess_str_util.dir_par_str(visflow_dir, 
            str_type="print")

    # strings for file names ("file") and figure titles ("print")
    scale_str = sess_str_util.scale_par_str(scale, "file")
    scale_str_pr = sess_str_util.scale_par_str(scale, "file").replace("_", 
        " ")
    ctrl_str = sess_str_util.ctrl_par_str(ctrl)
    ctrl_str_pr = sess_str_util.ctrl_par_str(ctrl, str_type="print")
    modif_str = "_modif" if modif else ""

    save_dir = Path(output, f"figures_{fluor}")
    save_dir.mkdir(exist_ok=True)

    # select the dataframe lines matching the requested criteria
    cols = ["scale", "fluor", "visflow_dir", "runtype"]
    cri = [scale, fluor, visflow_dir, runtype]
    plot_lines = gen_util.get_df_vals(summ_scores, cols, cri)
    cri_str = ", ".join([f"{col}: {crit}" for col, crit in zip(cols, cri)])
    if len(plot_lines) == 0: # no data
        warnings.warn(f"No data found for {cri_str}", 
            category=RuntimeWarning, stacklevel=1)
        return
    else: # shuffle or non shuffle missing
        skip = False
        for shuff in [False, True]:
            if shuff not in plot_lines["shuffle"].tolist():
                warnings.warn(
                    f"No shuffle={shuff} data found for {cri_str}", 
                    category=RuntimeWarning, stacklevel=1)
                skip = True
        if skip:
            return

    # plot and save one figure per datatype
    for data, data_title in zip(data_types, data_titles):
        if not modif:
            title = (f"{stim_str_pr.capitalize()} {comp}{ctrl_str_pr} - "
                f"{data_title} for log regr on\n" 
                + u"{} {} ".format(scale_str_pr, fluor) 
                + f"data ({runtype})")
        else:
            title = (f"{stim_str_pr.capitalize()} {comp}{ctrl_str_pr}\n"
                f"{data_title}")

        # underscores in the file-name strings are not wanted in titles
        if "_" in title:
            title = title.replace("_", " ")

        savename = (f"{data}_{stim_str}_{comp}{ctrl_str}{scale_str}"
            f"{modif_str}.svg")
        full_savename = Path(save_dir, savename)
        plot_data_summ(plot_lines, data, stats, shuff_stats, title, 
            full_savename, CI, q1v4, evu, comp, modif)

    plot_util.cond_close_figs()
def main(args): """ main(args) Runs analyses with parser arguments. Required args: - args (dict): parser argument dictionary """ # set logger to the specified level logger_util.set_level(level=args.log_level) if args.datadir is None: args.datadir = DEFAULT_DATADIR else: args.datadir = Path(args.datadir) args.mouse_df_path = DEFAULT_MOUSE_DF_PATH # Directory with additional fonts args.fontdir = DEFAULT_FONTDIR if DEFAULT_FONTDIR.exists() else None # warn if parallel is not used if args.overwrite and not (args.plot_only): if not args.parallel: warnings.warn( "Unless memory demands are too high for the machine being " "used, it is strongly recommended that paper analyses be run " "with the '--parallel' argument (enables computations to be " "distributed across available CPU cores). Otherwise, analyses " "may be very slow.", category=UserWarning, stacklevel=1) time.sleep(paper_organization.WARNING_SLEEP) # run through figure(s) and panel(s) if args.figure == "all": figures = paper_organization.get_all_figures() else: figures = [args.figure] sessions = None panel = args.panel for args.figure in figures: if panel == "all": panels = paper_organization.get_all_panels(args.figure) else: panels = [panel] for p, args.panel in enumerate(panels): new_fig = (p == 0) try: with gen_util.TimeIt(): sessions = run_single_panel(args, sessions=sessions, new_fig=new_fig) except Exception as err: sep = DOUBLE_SEP if new_fig else SEP if "Cannot plot figure panel" in str(err): lead = f"{sep}Fig. {args.figure}{args.panel.upper()}" logger.info(f"{lead}. {err}") else: raise err plot_util.cond_close_figs()