def reformat_args(args):
    """
    reformat_args(args)

    Returns reformatted args for analyses. Specifically:
    - Sets stimulus parameters to "none" if they are irrelevant to the
      stimtype
    - Changes stimulus parameters from "both" to actual values
    - Sets the seed value, though doesn't seed
    - Modifies analyses (if "all" or "all_" in parameter)

    Adds the following args:
        - dend (str)     : type of dendrites to use ("allen", "extr")
        - omit_sess (str): sess to omit
        - omit_mice (str): mice to omit

    Required args:
        - args (Argument parser): parser with the following attributes:
            runtype (str) : runtype ("pilot" or "prod")
            stimtype (str): stimulus to analyse (visflow or gabors)

    Returns:
        - args (Argument parser): input parser, with the following
            attributes added:
                visflow_dir, visflow_size, gabfr, gabk, gab_ori,
                omit_sess, omit_mice, dend, analyses, seed
    """
    args = copy.deepcopy(args)

    # resolve stimulus parameters for the stimulus type, using defaults
    stim_params = sess_gen_util.get_params(
        args.stimtype, "both", 128, "any", 16, "any")
    (args.visflow_dir, args.visflow_size, args.gabfr, args.gabk,
     args.gab_ori) = stim_params

    # set dend to "allen" for somata
    if args.plane == "soma":
        args.dend = "allen"

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype, args.visflow_dir, args.visflow_size,
        args.gabk)

    # choose a seed if none is provided (i.e., args.seed=-1), but seed later
    args.seed = rand_util.seed_all(
        args.seed, "cpu", log_seed=False, seed_now=False)

    # expand the analysis letters if "all" is requested, with optional
    # exclusions after "_"
    all_analyses = "".join(get_analysis_fcts().keys())
    wants_all = "all" in args.analyses
    has_excl = "_" in args.analyses
    if wants_all and has_excl:
        excl = args.analyses.split("_")[1]
        args.analyses, _ = gen_util.remove_lett(all_analyses, excl)
    elif wants_all:
        args.analyses = all_analyses
    elif has_excl:
        raise ValueError("Use '_' in args.analyses only with 'all'.")

    return args
def _set_seed(self):
    """
    self._set_seed()

    Updates attributes related to random process seeding.

    Updates the following attributes:
        - paper_seed (bool): whether the paper seed will be used
        - seed (int)       : specific seed that will be used
                             (None, if self.randomness is False)

    Raises:
        - RuntimeError if self._set_power() has not been run first
          (checked via the n_perms attribute).
    """
    if not hasattr(self, "n_perms"):
        raise RuntimeError("Must run self._set_power() first.")

    # no randomness requested: no seed needed
    if not self.randomness:
        self.seed = None
        self.paper_seed = True
        return

    # explicit request for the paper seed
    if self.seed == "paper":
        self.seed = PAPER_SEED
        self.paper_seed = True
        return

    self.seed = int(self.seed)
    self.paper_seed = False

    if self.seed == -1:
        # select any seed but the paper seed
        self.seed = PAPER_SEED
        while self.seed == PAPER_SEED:
            self.seed = rand_util.seed_all(
                -1, "cpu", log_seed=False, seed_now=False)

    # a non-paper seed warrants a warning (results may differ from paper)
    if self.seed != PAPER_SEED:
        self.warnings.append(seed_warning(self.seed))
def run_mag_change(sessions, analysis, seed, analyspar, sesspar, stimpar,
                   permpar, quantpar, figpar, datatype="roi"):
    """
    run_mag_change(sessions, analysis, seed, analyspar, sesspar, stimpar,
                   permpar, quantpar, figpar)

    Calculates and plots the magnitude of change in activity of ROIs between
    the first and last quantile for expected vs unexpected sequences.
    Saves results and parameters relevant to analysis in a dictionary
    (written to a JSON file).

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "m")
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - permpar (PermPar)    : named tuple containing permutation parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters

    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
    """
    # build strings describing the session/stimulus/dendrite parameters for
    # logging
    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
        sesspar.plane, stimpar.visflow_dir, stimpar.visflow_size,
        stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
        datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Calculating and plotting the magnitude changes in {datastr} "
        f"activity across quantiles \n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    # if multiple comparisons are requested, scale the correction by the
    # number of sessions (rebinding permpar locally only)
    if permpar.multcomp:
        permpar = sess_ntuple_util.get_modif_ntuple(permpar, "multcomp",
            len(sessions))

    # get full data: session x unexp x quants of interest x [ROI x seq]
    integ_info = quant_analys.trace_stats_by_qu_sess(sessions, analyspar,
        stimpar, quantpar.n_quants, quantpar.qu_idx, by_exp=True, integ=True,
        ret_arr=True, datatype=datatype)
    # last two elements are the counts and the full data arrays
    # (presumably guaranteed by ret_arr=True — see trace_stats_by_qu_sess)
    all_counts = integ_info[-2]
    qu_data = integ_info[-1]

    # extract session info
    mouse_ns = [sess.mouse_n for sess in sessions]
    lines = [sess.line for sess in sessions]

    # if bad values were already removed, no NaN policy is needed; otherwise
    # omit NaNs in statistics
    if analyspar.rem_bad:
        nanpol = None
    else:
        nanpol = "omit"

    # seed now (unlike seed_now=False elsewhere), as qu_mags uses permutations
    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    mags = quant_analys.qu_mags(qu_data, permpar, mouse_ns, lines,
        analyspar.stats, analyspar.error, nanpol=nanpol, op_qu="diff",
        op_unexp="diff")

    # convert mags items to list (deepcopy first, to avoid mutating the
    # returned structure in place)
    mags = copy.deepcopy(mags)
    mags["all_counts"] = all_counts
    for key in ["mag_st", "L2", "mag_rel_th", "L2_rel_th"]:
        # assumes these entries are numpy arrays — tolist() for JSON
        mags[key] = mags[key].tolist()

    sess_info = sess_gen_util.get_sess_info(sessions, analyspar.fluor,
        incl_roi=(datatype == "roi"), rem_bad=analyspar.rem_bad)

    extrapar = {"analysis": analysis, "datatype": datatype, "seed": seed}

    # collect all parameters and results for saving
    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "permpar": permpar._asdict(),
        "quantpar": quantpar._asdict(),
        "mags": mags,
        "sess_info": sess_info
    }

    # plot, then save the dictionary alongside the figures
    fulldir, savename = gen_plots.plot_mag_change(figpar=figpar, **info)
    file_util.saveinfo(info, savename, fulldir, "json")
def run_traces_by_qu_lock_sess(sessions, analysis, seed, analyspar, sesspar,
                               stimpar, quantpar, figpar, datatype="roi"):
    """
    run_traces_by_qu_lock_sess(sessions, analysis, analyspar, sesspar,
                               stimpar, quantpar, figpar)

    Retrieves trace statistics by session x quantile at the transition of
    expected to unexpected sequences (or v.v.) and plots traces across ROIs
    by quantile with each session in a separate subplot.

    Also runs analysis for one quantile (full data) with different unexpected
    lengths grouped separately.

    Saves results and parameters relevant to analysis in a dictionary
    (one JSON file per baseline x quantile setting x lock combination).

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "l")
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters

    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")
    """
    # build strings describing the session/stimulus/dendrite parameters for
    # logging
    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
        sesspar.plane, stimpar.visflow_dir, stimpar.visflow_size,
        stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
        datatype, "print")
    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces locked to unexpected onset by quantile ({quantpar.n_quants}) "
        f"\n({sessstr_pr}{dendstr_pr}).", extra={"spacing": "\n"})

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    # modify quantpar to retain all quantiles: one single-quantile version
    # and one multi-quantile version are analysed below
    quantpar_one = sess_ntuple_util.init_quantpar(1, 0)
    n_quants = quantpar.n_quants
    quantpar_mult = sess_ntuple_util.init_quantpar(n_quants, "all")

    # pre/post window (s) around the lock point, per stimulus type
    if stimpar.stimtype == "visflow":
        pre_post = [2.0, 6.0]
    elif stimpar.stimtype == "gabors":
        pre_post = [2.0, 8.0]
    else:
        gen_util.accepted_values_error("stimpar.stimtype", stimpar.stimtype,
            ["visflow", "gabors"])
    logger.warning("Setting pre to {}s and post to {}s.".format(*pre_post))

    # rebind stimpar locally with the overridden pre/post values
    stimpar = sess_ntuple_util.get_modif_ntuple(stimpar, ["pre", "post"],
        pre_post)

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for baseline in [None, stimpar.pre]:
        basestr_pr = sess_str_util.base_par_str(baseline, "print")
        # NOTE: the loop variable shadows the quantpar argument below
        for quantpar in [quantpar_one, quantpar_mult]:
            locks = ["unexp", "exp"]
            # the split-by-unexpected-length analysis is only run on the
            # single-quantile (full data) version
            if quantpar.n_quants == 1:
                locks.append("unexp_split")
            # get the stats (all) separating by session and quantiles
            for lock in locks:
                logger.info(
                    f"{quantpar.n_quants} quant, {lock} lock{basestr_pr}",
                    extra={"spacing": "\n"})
                if lock == "unexp_split":
                    trace_info = quant_analys.trace_stats_by_exp_len_sess(
                        sessions, analyspar, stimpar, quantpar.n_quants,
                        quantpar.qu_idx, byroi=False, nan_empty=True,
                        baseline=baseline, datatype=datatype)
                else:
                    trace_info = quant_analys.trace_stats_by_qu_sess(
                        sessions, analyspar, stimpar, quantpar.n_quants,
                        quantpar.qu_idx, byroi=False, lock=lock,
                        nan_empty=True, baseline=baseline, datatype=datatype)

                # for comparison, locking to middle of expected sample
                # (1 quant)
                exp_samp = quant_analys.trace_stats_by_qu_sess(
                    sessions, analyspar, stimpar, quantpar_one.n_quants,
                    quantpar_one.qu_idx, byroi=False, lock="exp_samp",
                    nan_empty=True, baseline=baseline, datatype=datatype)

                extrapar = {
                    "analysis": analysis,
                    "datatype": datatype,
                    "seed": seed,
                }

                # convert arrays to lists for JSON serialization
                xrans = [xran.tolist() for xran in trace_info[0]]
                all_stats = [sessst.tolist() for sessst in trace_info[1]]
                exp_stats = [expst.tolist() for expst in exp_samp[1]]
                trace_stats = {
                    "xrans": xrans,
                    "all_stats": all_stats,
                    "all_counts": trace_info[2],
                    "lock": lock,
                    "baseline": baseline,
                    "exp_stats": exp_stats,
                    "exp_counts": exp_samp[2]
                }

                # unexpected sequence lengths are only returned by the
                # split analysis
                if lock == "unexp_split":
                    trace_stats["unexp_lens"] = trace_info[3]

                sess_info = sess_gen_util.get_sess_info(
                    sessions, analyspar.fluor, incl_roi=(datatype == "roi"),
                    rem_bad=analyspar.rem_bad)

                info = {
                    "analyspar": analyspar._asdict(),
                    "sesspar": sesspar._asdict(),
                    "stimpar": stimpar._asdict(),
                    "quantpar": quantpar._asdict(),
                    "extrapar": extrapar,
                    "sess_info": sess_info,
                    "trace_stats": trace_stats
                }

                # plot, then save the dictionary alongside the figures
                fulldir, savename = gen_plots.plot_traces_by_qu_lock_sess(
                    figpar=figpar, **info)
                file_util.saveinfo(info, savename, fulldir, "json")
def run_sess_logreg(sess, analyspar, stimpar, logregpar, n_splits=100,
                    n_shuff_splits=300, seed=None, parallel=False):
    """
    run_sess_logreg(sess, analyspar, stimpar, logregpar)

    Runs logistic regressions on a session (real data and shuffled), and
    returns statistics dataframes.

    Required args:
        - sess (Session)       : Session object
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - logregpar (LogRegPar): named tuple containing logistic regression
                                 parameters

    Optional args:
        - n_splits (int)      : number of data splits to run logistic
                                regressions on
                                default: 100
        - n_shuff_splits (int): number of shuffled data splits to run
                                logistic regressions on
                                default: 300
        - seed (int)          : seed value to use. (-1 treated as None)
                                default: None
        - parallel (bool)     : if True, some of the analysis is run in
                                parallel across CPU cores
                                default: False

    Returns:
        - data_stats_df (pd.DataFrame): dataframe with only one data row
            containing data stats for each score and data subset.
        - shuffle_df (pd.DataFrame): dataframe where each row contains data
            for different data shuffles, and each column contains data for
            each score and data subset.

    Raises:
        - NotImplementedError if logregpar.q1v4 or logregpar.exp_v_unexp.
        - ValueError if n_splits or n_shuff_splits is not positive.
    """
    # resolve the seed value without seeding yet (seed_now=False)
    # NOTE(review): `seed + b` below requires seed_all to return an int even
    # for seed=None — confirm against rand_util.seed_all
    seed = rand_util.seed_all(seed, log_seed=False, seed_now=False)

    # retrieve data
    input_data, target_data, ctrl_ns = get_decoding_data(sess, analyspar,
        stimpar, comp=logregpar.comp, ctrl=logregpar.ctrl)

    scores_df = misc_analys.get_check_sess_df([sess], None, analyspar)
    common_columns = scores_df.columns.tolist()
    logreg_columns = ["comp", "ctrl", "bal", "shuffle"]

    # do checks
    if logregpar.q1v4 or logregpar.exp_v_unexp:
        raise NotImplementedError("q1v4 and exp_v_unexp are not implemented.")
    if n_splits <= 0 or n_shuff_splits <= 0:
        raise ValueError("n_splits and n_shuff_splits must be greater than 0.")

    # all (set, score) combinations to collect from each run
    set_types = ["train", "test"]
    score_types = ["neg_log_loss", "accuracy", "balanced_accuracy"]
    set_score_types = list(itertools.product(set_types, score_types))

    extrapar = dict()
    for shuffle in [False, True]:
        n_runs = n_shuff_splits if shuffle else n_splits
        extrapar["shuffle"] = shuffle
        temp_dfs = []
        # run in batches of at most MAX_SIMULT_RUNS (module-level constant)
        for b, n in enumerate(range(0, n_runs, MAX_SIMULT_RUNS)):
            # n_runs for this batch (the last batch may be smaller)
            extrapar["n_runs"] = int(np.min([MAX_SIMULT_RUNS, n_runs - n]))
            # silence info-level logging during the regression runs
            with logger_util.TempChangeLogLevel(level="warning"):
                # randst offset by batch index, so batches differ but the
                # overall run is reproducible
                mod_cvs, _, _ = logreg_util.run_logreg_cv_sk(
                    input_data, target_data, logregpar._asdict(), extrapar,
                    analyspar.scale, ctrl_ns, randst=seed + b,
                    parallel=parallel, save_models=False,
                    catch_set_prob=False)
            temp_df = pd.DataFrame()
            for set_type, score_type in set_score_types:
                key = f"{set_type}_{score_type}"
                temp_df[key] = mod_cvs[key]
            temp_dfs.append(temp_df)

        # compile batch scores, and get session stats for non shuffled data
        temp_df = pd.concat(temp_dfs, ignore_index=True)
        if not shuffle:
            temp_df = get_df_stats(temp_df, analyspar)

        # add columns to df (session info, then logreg parameters)
        score_columns = temp_df.columns.tolist()
        for col in common_columns:
            temp_df[col] = scores_df.loc[0, col]
        for col in logreg_columns:
            if col != "shuffle":
                temp_df[col] = logregpar._asdict()[col]
            else:
                temp_df[col] = shuffle

        # re-sort columns
        temp_df = temp_df.reindex(
            common_columns + logreg_columns + score_columns, axis=1)

        # keep one dataframe per shuffle setting
        if shuffle:
            shuffle_df = temp_df
        else:
            data_stats_df = temp_df

    return data_stats_df, shuffle_df
def reformat_args(args):
    """
    reformat_args(args)

    Returns reformatted args for analyses. Specifically:
    - Sets stimulus parameters to "none" if they are irrelevant to the
      stimtype
    - Changes stimulus parameters from "both" to actual values
    - Sets the seed value, though doesn't seed
    - Modifies analyses (if "all" or "all_" in parameter)

    Adds the following args:
        - omit_sess (str): sess to omit
        - omit_mice (str): mice to omit

    Required args:
        - args (Argument parser): parser with the following attributes:
            visflow_dir (str)        : visual flow direction values to
                                       include (e.g., "right", "left" or
                                       "both")
            visflow_size (int or str): visual flow size values to include
                                       (e.g., 128, 256, "both")
            gabfr (int)              : gabor frame value to start sequences
                                       at (e.g., 0, 1, 2, 3)
            gabk (int or str)        : gabor kappa values to include
                                       (e.g., 4, 16 or "both")
            gab_ori (int or str)     : gabor orientation values to include
                                       (e.g., 0, 45, 90, 135, 180, 225 or
                                       "all")
            runtype (str)            : runtype ("pilot" or "prod")
            stimtype (str)           : stimulus to analyse (visflow or
                                       gabors)

    Returns:
        - args (Argument parser): input parser, with the following
            attributes modified:
                visflow_dir, visflow_size, gabfr, gabk, gab_ori, grps,
                analyses, seed
            and the following attributes added:
                omit_sess, omit_mice
    """
    args = copy.deepcopy(args)

    # resolve stimulus parameters from the requested values
    [args.visflow_dir, args.visflow_size, args.gabfr, args.gabk,
     args.gab_ori] = sess_gen_util.get_params(
        args.stimtype, args.visflow_dir, args.visflow_size, args.gabfr,
        args.gabk, args.gab_ori)

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype, args.visflow_dir, args.visflow_size,
        args.gabk)

    # choose a seed if none is provided (i.e., args.seed=-1), but seed later
    args.seed = rand_util.seed_all(
        args.seed, "cpu", log_seed=False, seed_now=False)

    # expand the analysis letters ("_" marks letters to exclude, and is only
    # valid together with "all")
    all_analyses = "".join(get_analysis_fcts().keys())
    if "all" not in args.analyses:
        if "_" in args.analyses:
            raise ValueError("Use '_' in args.analyses only with 'all'.")
    elif "_" not in args.analyses:
        args.analyses = all_analyses
    else:
        excl = args.analyses.split("_")[1]
        args.analyses, _ = gen_util.remove_lett(all_analyses, excl)

    return args
def run_glms(sessions, analysis, seed, analyspar, sesspar, stimpar, glmpar,
             figpar, parallel=False):
    """
    run_glms(sessions, analysis, seed, analyspar, sesspar, stimpar, glmpar,
             figpar)

    Runs GLM analyses on sessions (per ROI and session, or across sessions),
    and saves the explained variance results and parameters to a JSON file.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - glmpar (GLMPar)      : named tuple containing GLM parameters
        - figpar (dict)        : dictionary containing figure parameters

    Optional args:
        - parallel (bool): if True, some of the analysis is run in parallel
                           across CPU cores
                           default: False
    """
    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    sessstr_pr = sess_str_util.sess_par_str(
        sesspar.sess_n, stimpar.stimtype, sesspar.plane, stimpar.visflow_dir,
        stimpar.visflow_size, stimpar.gabk, "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar.dend, sesspar.plane, "roi", "print")

    logger.info(
        "Analysing and plotting explained variance in ROI activity "
        f"({sessstr_pr}{dendstr_pr}).", extra={"spacing": "\n"})

    if glmpar.each_roi: # must do each session separately
        glm_type = "per_ROI_per_sess"
        sess_batches = sessions
        logger.info("Per ROI, each session separately.")
    else:
        glm_type = "across_sess"
        sess_batches = [sessions]
        logger.info(f"Across ROIs, {len(sessions)} sessions together.")

    # run batches in parallel at this level only when not analysing per ROI
    # and there is more than one batch; otherwise defer parallelization to
    # the next level down
    run_parallel = (
        parallel and not glmpar.each_roi and len(sess_batches) != 1)
    defer_parallel = bool(parallel and not run_parallel)

    shared_args = [analyspar, sesspar, stimpar, glmpar]
    shared_kwargs = {
        "parallel": defer_parallel, # proactively set next parallel
        "seed": seed,
    }
    all_expl_var = gen_util.parallel_wrap(
        run_glm, sess_batches, shared_args, shared_kwargs,
        parallel=run_parallel)

    # recover the flat session list from the batches
    sessions = sess_batches if glmpar.each_roi else sess_batches[0]

    sess_info = sess_gen_util.get_sess_info(
        sessions, analyspar.fluor, rem_bad=analyspar.rem_bad)

    extrapar = {
        "analysis": analysis,
        "seed": seed,
        "glm_type": glm_type,
    }

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "glmpar": glmpar._asdict(),
        "extrapar": extrapar,
        "all_expl_var": all_expl_var,
        "sess_info": sess_info
    }

    fulldir, savename = glm_plots.plot_glm_expl_var(figpar=figpar, **info)
    file_util.saveinfo(info, savename, fulldir, "json")

    return
def run_sess_lstm(sessid, args):
    """
    run_sess_lstm(sessid, args)

    Trains and evaluates an LSTM to predict ROI traces from stimulus
    parameters for one session, saving model checkpoints, a loss dataframe,
    a loss curve figure and a target-vs-predicted trace figure.

    NOTE: this function currently raises NotImplementedError after the
    training data is prepared (see below) — all code past that point is
    unreachable until the memory issue is fixed.

    Required args:
        - sessid (int)          : session ID
        - args (Argument parser): parser with analysis and model attributes
            (e.g., stimtype, datadir, mouse_df, runtype, output, device,
            seed, parallel, plt_bkend, hidden_dim, num_layers, lr_ex,
            batchsize, n_epochs, conv, out_ch, dropout)
    """
    if args.parallel and args.plt_bkend is not None:
        plt.switch_backend(args.plt_bkend) # needs to be repeated within joblib

    # seed all RNGs, including torch
    args.seed = rand_util.seed_all(args.seed, args.device, seed_torch=True)

    train_p = 0.8
    # learning rate from its exponent (e.g., lr_ex=4 -> 1e-4)
    lr = 1. * 10**(-args.lr_ex)

    # strings identifying the convolutional variant in filenames
    if args.conv:
        conv_str = "_conv"
        outch_str = f"_{args.out_ch}outch"
    else:
        conv_str = ""
        outch_str = ""

    # Input output parameters (window lengths, in seconds)
    n_stim_s = 0.6
    n_roi_s = 0.3

    # Stim/traces for training
    train_gabfr = 0
    train_post = 0.9 # up to C
    roi_train_pre = 0 # from A
    stim_train_pre = 0.3 # from preceeding grayscreen

    # Stim/traces for testing (separated for unexp vs exp)
    test_gabfr = 3
    test_post = 0.6 # up to grayscreen
    roi_test_pre = 0 # from D/U
    stim_test_pre = 0.3 # from preceeding C

    sess = sess_gen_util.init_sessions(
        sessid, args.datadir, args.mouse_df, args.runtype, full_table=False,
        fluor="dff", dend="extr", run=True, temp_log="warning")[0]

    # output directory for this session's models and figures
    analysdir = sess_gen_util.get_analysdir(
        sess.mouse_n, sess.sess_n, sess.plane, stimtype=args.stimtype,
        comp=None)
    dirname = Path(args.output, analysdir)
    file_util.createdir(dirname, log_dir=False)

    # Must not scale ROIs or running BEFOREHAND. Must do after to use only
    # network available data.

    # seq x frame x gabor x par
    logger.info("Preparing stimulus parameter dataframe",
        extra={"spacing": "\n"})
    train_stim_wins, run_stats = sess_data_util.get_stim_data(
        sess, args.stimtype, n_stim_s, train_gabfr, stim_train_pre,
        train_post, gabk=16, run=True)

    logger.info("Adding ROI data")
    xran, train_roi_wins, roi_stats = sess_data_util.get_roi_data(
        sess, args.stimtype, n_roi_s, train_gabfr, roi_train_pre, train_post,
        gabk=16)

    logger.warning("Preparing windowed datasets (too slow - to be improved)")

    # everything below is unreachable until this is resolved
    raise NotImplementedError("Not implemented properly - some error leads "
        "to excessive memory requests.")

    # build test sets, one per unexpectedness value, normalized with the
    # training run/ROI statistics
    test_stim_wins = []
    test_roi_wins = []
    for unexp in [0, 1]:
        stim_wins = sess_data_util.get_stim_data(
            sess, args.stimtype, n_stim_s, test_gabfr, stim_test_pre,
            test_post, unexp, gabk=16, run_mean=run_stats[0],
            run_std=run_stats[1])
        test_stim_wins.append(stim_wins)
        roi_wins = sess_data_util.get_roi_data(sess, args.stimtype, n_roi_s,
            test_gabfr, roi_test_pre, test_post, unexp, gabk=16,
            roi_means=roi_stats[0], roi_stds=roi_stats[1])[1]
        test_roi_wins.append(roi_wins)

    n_pars = train_stim_wins.shape[-1] # n parameters (121)
    n_rois = train_roi_wins.shape[-1] # n ROIs

    # hyperparameter string used in all output filenames
    hyperstr = (f"{args.hidden_dim}hd_{args.num_layers}hl_{args.lr_ex}lrex_"
        f"{args.batchsize}bs{outch_str}{conv_str}")

    # train/validation dataloaders (test_p=0: test sets are built separately)
    dls = data_util.create_dls(train_stim_wins, train_roi_wins,
        train_p=train_p, test_p=0, batchsize=args.batchsize, thresh_cl=0,
        train_shuff=True)[0]
    train_dl, val_dl, _ = dls
    test_dls = []
    for s in [0, 1]:
        dl = data_util.init_dl(test_stim_wins[s], test_roi_wins[s],
            batchsize=args.batchsize)
        test_dls.append(dl)

    logger.info("Running LSTM")
    if args.conv:
        lstm = ConvPredROILSTM(args.hidden_dim, n_rois, out_ch=args.out_ch,
            num_layers=args.num_layers, dropout=args.dropout)
    else:
        lstm = PredLSTM(n_pars, args.hidden_dim, n_rois,
            num_layers=args.num_layers, dropout=args.dropout)
    lstm = lstm.to(args.device)
    # NOTE(review): `size_average` is deprecated in recent torch versions —
    # the equivalent is MSELoss(reduction="sum")
    lstm.loss_fn = torch.nn.MSELoss(size_average=False)
    lstm.opt = torch.optim.Adam(lstm.parameters(), lr=lr)

    loss_df = pd.DataFrame(
        np.nan, index=range(args.n_epochs), columns=["train", "val"])
    min_val = np.inf
    for ep in range(args.n_epochs):
        logger.info(f"====> Epoch {ep}", extra={"spacing": "\n"})
        # epoch 0 only evaluates (no weight updates), to record initial loss
        if ep == 0:
            train_loss = run_dl(lstm, train_dl, args.device, train=False)
        else:
            train_loss = run_dl(lstm, train_dl, args.device, train=True)
        val_loss = run_dl(lstm, val_dl, args.device, train=False)
        # store per-sample losses
        loss_df["train"].loc[ep] = train_loss/train_dl.dataset.n_samples
        loss_df["val"].loc[ep] = val_loss/val_dl.dataset.n_samples
        logger.info(f"Training loss : {loss_df['train'].loc[ep]}")
        logger.info(f"Validation loss: {loss_df['val'].loc[ep]}")

        # record model if training is lower than val, and val reaches a new
        # low
        # NOTE(review): only the new-low condition is actually implemented —
        # the train-vs-val part of the comment above has no matching check
        if ep == 0 or val_loss < min_val:
            prev_model = glob.glob(str(Path(dirname, f"{hyperstr}_ep*.pth")))
            prev_df = glob.glob(str(Path(dirname, f"{hyperstr}.csv")))
            min_val = val_loss
            saved_ep = ep
            # remove the previously saved checkpoint and loss dataframe
            if len(prev_model) == 1 and len(prev_df) == 1:
                Path(prev_model[0]).unlink()
                Path(prev_df[0]).unlink()
            savename = f"{hyperstr}_ep{ep}"
            savefile = Path(dirname, savename)
            torch.save({"net": lstm.state_dict(),
                "opt": lstm.opt.state_dict()}, f"{savefile}.pth")
            file_util.saveinfo(loss_df, hyperstr, dirname, "csv")

    # plot the loss curves
    plot_util.linclab_plt_defaults(font=["Arial", "Liberation Sans"],
        fontdir=DEFAULT_FONTDIR)
    fig, ax = plt.subplots(1)
    for dataset in ["train", "val"]:
        plot_util.plot_traces(ax, range(args.n_epochs),
            np.asarray(loss_df[dataset]), label=dataset,
            title=f"Average loss (MSE) ({n_rois} ROIs)", xticks="auto")
    fig.savefig(Path(dirname, f"{hyperstr}_loss"))

    # reload the best (lowest validation loss) checkpoint
    savemod = Path(dirname, f"{hyperstr}_ep{saved_ep}.pth")
    checkpoint = torch.load(savemod)
    lstm.load_state_dict(checkpoint["net"])

    # sample validation windows to plot target vs predicted traces
    n_samples = 20
    val_idx = np.random.choice(range(val_dl.dataset.n_samples), n_samples)
    val_samples = val_dl.dataset[val_idx]
    xrans = data_util.get_win_xrans(xran, val_samples[1].shape[1],
        val_idx.tolist())

    fig, ax = plot_util.init_fig(n_samples, ncols=4, sharex=True,
        subplot_hei=2, subplot_wid=5)
    lstm.eval()
    with torch.no_grad():
        # assumes the LSTM takes (seq, batch, ...) input — hence the
        # transposes around the forward pass
        batch_len, seq_len, n_items = val_samples[1].shape
        pred_tr = lstm(val_samples[0].transpose(1, 0).to(args.device))
        pred_tr = pred_tr.view([seq_len, batch_len, n_items]).transpose(1, 0)
    # plot one random ROI per sampled window, for targets then predictions
    for lab, data in zip(["target", "pred"], [val_samples[1], pred_tr]):
        data = data.numpy()
        for n in range(n_samples):
            roi_n = np.random.choice(range(data.shape[-1]))
            sub_ax = plot_util.get_subax(ax, n)
            plot_util.plot_traces(sub_ax, xrans[n], data[n, :, roi_n],
                label=lab, xticks="auto")
            plot_util.set_ticks(sub_ax, "x", xran[0], xran[-1], n=7)
    sess_plot_util.plot_labels(ax, train_gabfr, plot_vals="exp",
        pre=roi_train_pre, post=train_post)
    fig.suptitle(f"Target vs predicted validation traces ({n_rois} ROIs)")
    fig.savefig(Path(dirname, f"{hyperstr}_traces"))
def _parse_range(val, name):
    """
    _parse_range(val, name)

    Returns the inclusive integer range encoded by a "st-end" string
    (e.g., "1-3" -> [1, 2, 3]).

    Required args:
        - val (str or int): range string (e.g., "1-3")
        - name (str)      : argument name, used in the error message

    Returns:
        - (list): list of integers in the inclusive range

    Raises:
        - ValueError if val does not contain exactly one "-" separator.
    """
    vals = str(val).split("-")
    if len(vals) != 2:
        raise ValueError(
            f"If {name} is a range, must have format 1-3.")
    st = int(vals[0])
    end = int(vals[1]) + 1
    return list(range(st, end))


def reformat_args(args):
    """
    reformat_args(args)

    Returns reformatted args for analyses, specifically
        - Sets stimulus parameters to "none" if they are irrelevant to the
          stimtype
        - Changes stimulus parameters from "both" to actual values
        - Modifies the session number parameter
        - Sets seed, though doesn't seed
        - Modifies analyses (if "all" or "all_" in parameter)
        - Sets latency parameters based on lat_method

    Adds the following args:
        - dend (str)     : type of dendrites to use ("allen" or "extr")
        - omit_sess (str): sess to omit
        - omit_mice (str): mice to omit

    Required args:
        - args (Argument parser): parser with the following attributes:
            visflow_dir (str)        : visual flow direction values to
                                       include (e.g., "right", "left" or
                                       "both")
            visflow_size (int or str): visual flow size values to include
                                       (e.g., 128, 256, "both")
            gabfr (int)              : gabor frame value to start sequences
                                       at (e.g., 0, 1, 2, 3)
            gabk (int or str)        : gabor kappa values to include
                                       (e.g., 4, 16 or "both")
            gab_ori (int or str)     : gabor orientation values to include
                                       (e.g., 0, 45, 90, 135, 180, 225 or
                                       "all")
            mouse_ns (str)           : mouse numbers or range
                                       (e.g., 1, "1,3", "1-3", "all")
            runtype (str)            : runtype ("pilot" or "prod")
            sess_n (str)             : session number range
                                       (e.g., "1-1", "all")
            stimtype (str)           : stimulus to analyse (visflow or
                                       gabors)

    Returns:
        - args (Argument parser): input parser, with the following
            attributes modified:
                visflow_dir, visflow_size, gabfr, gabk, gab_ori, sess_n,
                mouse_ns, analyses, seed, lat_p_val_thr, lat_rel_std
            and the following attributes added:
                omit_sess, omit_mice, dend
    """
    args = copy.deepcopy(args)

    # set dend to "allen" for somata (fix: this was previously duplicated
    # after the get_params() call, with no added effect)
    if args.plane == "soma":
        args.dend = "allen"

    # resolve stimulus parameters from the requested values
    [args.visflow_dir, args.visflow_size, args.gabfr, args.gabk,
     args.gab_ori] = sess_gen_util.get_params(
        args.stimtype, args.visflow_dir, args.visflow_size, args.gabfr,
        args.gabk, args.gab_ori)

    # fluorescence type is not applicable to running data
    if args.datatype == "run":
        args.fluor = "n/a"

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype, args.visflow_dir, args.visflow_size,
        args.gabk)

    # expand session number range strings (e.g., "1-3" -> [1, 2, 3])
    if "-" in str(args.sess_n):
        args.sess_n = _parse_range(args.sess_n, "args.sess_n")

    # retain only the latency parameter relevant to the latency method
    if args.lat_method == "ratio":
        args.lat_p_val_thr = None
    elif args.lat_method == "ttest":
        args.lat_rel_std = None

    # choose a seed if none is provided (i.e., args.seed=-1), but seed later
    args.seed = rand_util.seed_all(
        args.seed, "cpu", log_seed=False, seed_now=False)

    # collect mouse numbers from args.mouse_ns
    if "," in args.mouse_ns:
        args.mouse_ns = [int(n) for n in args.mouse_ns.split(",")]
    elif "-" in args.mouse_ns:
        args.mouse_ns = _parse_range(args.mouse_ns, "args.mouse_ns")
    elif args.mouse_ns not in ["all", "any"]:
        args.mouse_ns = int(args.mouse_ns)

    # collect analysis letters ("_" marks letters to exclude, and is only
    # valid together with "all")
    all_analyses = "".join(get_analysis_fcts().keys())
    if "all" in args.analyses:
        if "_" in args.analyses:
            excl = args.analyses.split("_")[1]
            args.analyses, _ = gen_util.remove_lett(all_analyses, excl)
        else:
            args.analyses = all_analyses
    elif "_" in args.analyses:
        raise ValueError("Use '_' in args.analyses only with 'all'.")

    return args
def plot_violin_data(sub_ax, xs, all_data, palette=None, dashes=None,
                     seed=None):
    """
    plot_violin_data(sub_ax, xs, all_data)

    Plots a white violin for each data group, with individual data points
    overlaid as a jittered scatter.

    Required args:
        - sub_ax (plt subplot): subplot
        - xs (list)           : x value for each data group
        - all_data (list)     : data for each data group

    Optional args:
        - palette (list): colors for each data group
                          default: None
        - dashes (list) : dash patterns for each data group
                          default: None
        - seed (int)    : seed value to use. (-1 treated as None)
                          default: None

    Raises:
        - ValueError if all_data, palette or dashes lengths do not match xs.
    """
    # seed for scatterplot (stripplot jitter is random)
    rand_util.seed_all(seed, log_seed=False)

    # validate group argument lengths
    if len(xs) != len(all_data):
        raise ValueError("xs must have the same length as all_data.")
    for name, per_group in [("palette", palette), ("dashes", dashes)]:
        if per_group is not None and len(xs) != len(per_group):
            raise ValueError(
                f"{name}, if provided, must have the same length as xs."
            )

    # flatten the groups into paired x/y arrays for seaborn
    group_xs = [
        np.full_like(group, x) for x, group in zip(xs, all_data)
    ]
    xs_arr = np.concatenate(group_xs)
    data_arr = np.concatenate(all_data)

    # add violins
    violin_ax = seaborn.violinplot(
        x=xs_arr, y=data_arr, inner=None, linewidth=3.5, color="white",
        ax=sub_ax
    )

    # edit contours (edge color for all violins, dashed outline if requested)
    for idx, violin in enumerate(violin_ax.collections):
        violin.set_edgecolor(plot_helper_fcts.NEARBLACK)
        if dashes is not None and dashes[idx] is not None:
            violin.set_linestyle(plot_helper_fcts.VDASH)

    # add data dots
    seaborn.stripplot(
        x=xs_arr, y=data_arr, size=9, jitter=0.2, alpha=0.3, palette=palette,
        ax=sub_ax
    )